Compiling Models with Large Constant Arrays (#146)
* PoC works. * MNist works. * Clean up. * Fix test. * Make Linux work. * Use consistent symbol name. * Fix variable name. * Fix array addr access. * Bug fix. * Bug fix. * install before running e2e tests. * Fix build config. * Use sudo when installing. * Make embeddedDataLoader position independent. * Enable ResNet50. * Format code. * Format MainUtil. * Try not using sudo to install. * Supply runtime dir via environment variable. * Dump problematic operation. * Dump entire function. * Debug. * Dump input. * Dump constant op. * Debug. * Debug. * Debug. * Print to stderr. * take care of endianness. * Use endianness-aware execution session. * Fix ZLinux error. * Include warning when desired output endianness can't be deduced. * Remove debug code. * Remove debug code in shape inference. * Support binary-decoder for testing constants packing. * Support filename, move-to-file, elision-threshold configurations in constant packing pass for easy testing. * Add lit test, fix lit test type mismatch. * Add more consts packing tests. * Ensure intermediate files are properly cleaned up. * No need for constant elimination. * Link with threading libraries. * Remove debug code. * Format code. * More tests. * test nit. * Remove debug code. * Reduce hard-coded constants. * Use temporary and unique working directory for hosting model parameters. * Test if it works. * Try to find objcopy. * Rename symbols using objcopy. * Move sanitized name to linux section. * Use verbose mode for debugging. * Disambiguate pass constructor. * Fix symbol name. * Use Command API to build and execute commands. * Move linux to use Command API. * Fix reset args. * Execute redefine sym. * Format code. * Do not use verbose mode for CircleCI. * Remove debug code. * Prettify code, add comments. * getSegmentData -> getEmbeddedConstPool * vector -> std::vector. * Make sure we properly clean up intermediate files. * Fix test cases. * Add runtime directory. * Trigger rebuild. * [Merge with master] fix debug script. * Diable affine fusion pass for now. * Support generic fallback const packing mechanism. * Remove debug code. * Handle the case where objcopy is not available. * Fix Windows missing types. * Support int64. * Copy packed constant to a local directory for non-Linux/Mac platforms. * Nit: remove debug code, refactor const pack preprocessing out as a separate function. * Cannot make preprocessConstPack a standalone function because file removers are stack-allocated, and they are deallocated prematurely when function stack gets popped, deleteing intermediate files too early. * Don't require executable filename. * Import ONNX data types directly. * Fix LIT test. * Bug fix, use moved string value. * Remove redundant filenames. * Fix CMake script. * Embed endianness information as a symbol, and check during runtime. * More comments, update lit tests. * Fix lit test on BE machine. * Copyright notices.
This commit is contained in:
parent
8c4d527eea
commit
e0ae583da0
|
@ -71,4 +71,4 @@ cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PATH} .. \
|
||||||
|
|
||||||
make -j$(nproc) onnx-mlir
|
make -j$(nproc) onnx-mlir
|
||||||
make -j$(nproc) check-onnx-lit
|
make -j$(nproc) check-onnx-lit
|
||||||
make -j$(nproc) check-onnx-backend
|
RUNTIME_DIR=$(pwd)/lib make -j$(nproc) check-onnx-backend
|
||||||
|
|
|
@ -40,14 +40,14 @@ jobs:
|
||||||
command: |
|
command: |
|
||||||
sudo pip install -q -e ./onnx-mlir/third_party/onnx
|
sudo pip install -q -e ./onnx-mlir/third_party/onnx
|
||||||
cd onnx-mlir/build
|
cd onnx-mlir/build
|
||||||
VERBOSE=1 cmake --build . --target check-onnx-backend
|
RUNTIME_DIR=$(pwd)/lib cmake --build . --target check-onnx-backend
|
||||||
- run:
|
- run:
|
||||||
name: Run Unit Tests
|
name: Run Unit Tests
|
||||||
command: |
|
command: |
|
||||||
cd onnx-mlir/build
|
cd onnx-mlir/build
|
||||||
# Need to include the bin directory in $PATH,
|
# Need to include the bin directory in $PATH,
|
||||||
# otherwise CTest fails to find the test executables.
|
# otherwise CTest fails to find the test executables.
|
||||||
PATH=$(pwd)/bin:$PATH make test -j$(nproc)
|
RUNTIME_DIR=$(pwd)/lib PATH=$(pwd)/bin:$PATH make test -j$(nproc)
|
||||||
- run:
|
- run:
|
||||||
name: Run DocCheck
|
name: Run DocCheck
|
||||||
command: cd onnx-mlir/build && cmake --build . --target check-doc
|
command: cd onnx-mlir/build && cmake --build . --target check-doc
|
||||||
|
|
|
@ -269,6 +269,7 @@ set(ONNXMLIRWholeArchiveLibs
|
||||||
OMPromotableConstOperandsOpInterface
|
OMPromotableConstOperandsOpInterface
|
||||||
OMElideConstants
|
OMElideConstants
|
||||||
OMElideKrnlGlobalConstants
|
OMElideKrnlGlobalConstants
|
||||||
|
OMPackKrnlGlobalConstants
|
||||||
OMEnableMemoryPool)
|
OMEnableMemoryPool)
|
||||||
|
|
||||||
# Function to construct linkage option for the static libraries that must be
|
# Function to construct linkage option for the static libraries that must be
|
||||||
|
|
|
@ -7,6 +7,8 @@
|
||||||
// Helper methods for handling input ONNX models.
|
// Helper methods for handling input ONNX models.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
#include <llvm/Support/Endian.h>
|
||||||
|
#include <llvm/Support/SwapByteOrder.h>
|
||||||
|
|
||||||
#include "src/Builder/FrontendDialectHelper.hpp"
|
#include "src/Builder/FrontendDialectHelper.hpp"
|
||||||
|
|
||||||
|
@ -104,8 +106,16 @@ static std::vector<T> CreateArrayAttribute(onnx::TensorProto initializer) {
|
||||||
std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
|
std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
|
||||||
back_inserter(byteInitializer));
|
back_inserter(byteInitializer));
|
||||||
size = initializer.raw_data().size() / sizeof(T);
|
size = initializer.raw_data().size() / sizeof(T);
|
||||||
T *res = reinterpret_cast<T *>(&byteInitializer[0]);
|
T *arrayPtr = reinterpret_cast<T *>(&byteInitializer[0]);
|
||||||
return std::vector<T>(res, res + size);
|
auto array = std::vector<T>(arrayPtr, arrayPtr + size);
|
||||||
|
// Perform byte swap if system endianness is BE.
|
||||||
|
// ONNX tensor content raw data is always in LE.
|
||||||
|
if (llvm::support::endian::system_endianness() !=
|
||||||
|
llvm::support::endianness::little)
|
||||||
|
for (int i = 0; i < array.size(); i++)
|
||||||
|
llvm::sys::swapByteOrder<T>(array[i]);
|
||||||
|
|
||||||
|
return array;
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy, no need to take care of endianness
|
// copy, no need to take care of endianness
|
||||||
|
|
|
@ -25,9 +25,6 @@ if(NOT EXISTS "${LLVM_PROJ_BUILD}/bin/llc")
|
||||||
message(ERROR "Cannot find llc.")
|
message(ERROR "Cannot find llc.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Get the compiler command name, the C++ compiler is needed to to translate
|
|
||||||
# object files to shared libraries.
|
|
||||||
get_filename_component(CXX_COMPILER_FILENAME ${CMAKE_CXX_COMPILER} NAME)
|
|
||||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in
|
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp)
|
${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp)
|
||||||
|
|
||||||
|
@ -57,6 +54,7 @@ target_link_libraries(MainUtils
|
||||||
OMResultTypeInferenceOpInterface
|
OMResultTypeInferenceOpInterface
|
||||||
OMElideConstants
|
OMElideConstants
|
||||||
OMElideKrnlGlobalConstants
|
OMElideKrnlGlobalConstants
|
||||||
|
OMPackKrnlGlobalConstants
|
||||||
OMEnableMemoryPool
|
OMEnableMemoryPool
|
||||||
OMKrnlToAffine
|
OMKrnlToAffine
|
||||||
OMKrnlToLLVM
|
OMKrnlToLLVM
|
||||||
|
@ -71,6 +69,9 @@ if (INCLUDE_ONNX_ML)
|
||||||
add_dependencies(MainUtils OMMLONNXOpsInc)
|
add_dependencies(MainUtils OMMLONNXOpsInc)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
add_dependencies(onnx-mlir cruntime)
|
||||||
|
add_dependencies(onnx-mlir EmbeddedDataLoader)
|
||||||
|
|
||||||
target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_SRC_ROOT})
|
target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_SRC_ROOT})
|
||||||
target_include_directories(onnx-mlir PRIVATE ${CMAKE_BINARY_DIR})
|
target_include_directories(onnx-mlir PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_BIN_ROOT})
|
target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_BIN_ROOT})
|
||||||
|
|
|
@ -38,9 +38,11 @@ struct ONNXConstantOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Emit the constant global in Krnl dialect.
|
// Emit the constant global in Krnl dialect.
|
||||||
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
|
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
|
||||||
rewriter.getI64ArrayAttr(shape),
|
/*shape=*/rewriter.getI64ArrayAttr(shape),
|
||||||
|
/*name=*/
|
||||||
rewriter.getStringAttr("constant_" + std::to_string(constantID)),
|
rewriter.getStringAttr("constant_" + std::to_string(constantID)),
|
||||||
constantOp.value().getValue());
|
/*value=*/constantOp.value().getValue(),
|
||||||
|
/*offset=*/nullptr);
|
||||||
|
|
||||||
// Increment constant ID:
|
// Increment constant ID:
|
||||||
constantID++;
|
constantID++;
|
||||||
|
|
|
@ -196,16 +196,66 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
|
||||||
def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
|
def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
|
||||||
let summary = "Krnl global operation";
|
let summary = "Krnl global operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
Operation for holding global data values.
|
Operation for holding global data values. A global constant can have a
|
||||||
|
meaningful name recorded as its `name` attribute. Its content is stored
|
||||||
|
in the `value` dense element attribute. Alternatively, if the constants
|
||||||
|
are packed together, `offset` records the byte offset in the global
|
||||||
|
constant pool from which the current constant is to be read.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyAttr:$shape, StrAttr:$name, OptionalAttr<AnyAttr>:$value);
|
let arguments = (ins AnyAttr:$shape,
|
||||||
|
StrAttr:$name, OptionalAttr<AnyAttr>:$value, OptionalAttr<I64Attr>:$offset);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
|
||||||
|
|
||||||
let parser = ?;
|
let parser = ?;
|
||||||
let printer = ?;
|
let printer = ?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def KrnlPackedConstantOp : Op<Krnl_Dialect, "packed_const"> {
|
||||||
|
let summary = "Krnl packed constant operation";
|
||||||
|
let description = [{
|
||||||
|
Operation for holding packed constants.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins I64Attr:$size_in_bytes,
|
||||||
|
BoolAttr:$is_le,
|
||||||
|
OptionalAttr<AnyIntElementsAttr<8>>:$value,
|
||||||
|
OptionalAttr<StrAttr>:$file_name);
|
||||||
|
let results = (outs I64:$output);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// The *path* to the file storing the constant pack on disk is encoded
|
||||||
|
// as a global variable in the LLVM module of the lowered model.
|
||||||
|
// getConstPackFilePathSymbolName returns the name of this symbol;
|
||||||
|
// getConstPackFilePathStrLenSymbolName returns the name of the symbol holding
|
||||||
|
// a constant value equal to the length of the file path.
|
||||||
|
static StringRef getConstPackFilePathSymbolName() { return "constPackFilePath"; }
|
||||||
|
static StringRef getConstPackFilePathStrLenSymbolName() { return "constPackFilePathStrLen"; }
|
||||||
|
|
||||||
|
// The *name* of the file storing the constant pack is also recorded for
|
||||||
|
// convenience. Similarly, getConstPackFileNameSymbolName and
|
||||||
|
// getConstPackFileNameStrLenSymbolName returns records the symbol holding
|
||||||
|
// the string constant representing the filename and the length of this
|
||||||
|
// string constant.
|
||||||
|
static StringRef getConstPackFileNameSymbolName() { return "constPackFileName"; }
|
||||||
|
static StringRef getConstPackFileNameStrLenSymbolName() { return "constPackFileNameStrLen"; }
|
||||||
|
|
||||||
|
// We record whether the constant pack is stored in LE byte order as a
|
||||||
|
// int8 symbol; it is to be interpreted as a boolean switch - with 0
|
||||||
|
// meaning that the constant pack is not stored in LE byte order and
|
||||||
|
// non-0 values meaning that it is stored in LE byte order.
|
||||||
|
static StringRef getConstPackIsLESymbolName() { return "constPackIsLE"; }
|
||||||
|
// The name of a function we call to read packed constants embedded within
|
||||||
|
// the current binary executable/library, or in the case of unsupported platform,
|
||||||
|
// from a binary constant pack file.
|
||||||
|
static StringRef getEmbeddedDataLoaderMethodName() {
|
||||||
|
return "getEmbeddedConstPool";
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let parser = ?;
|
||||||
|
let printer = ?;
|
||||||
|
}
|
||||||
|
|
||||||
def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> {
|
def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> {
|
||||||
let summary = "Krnl a MemRef from within another MemRef starting at a specific offset.";
|
let summary = "Krnl a MemRef from within another MemRef starting at a specific offset.";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -1138,7 +1138,6 @@ LogicalResult ONNXReshapeOp::inferShapes() {
|
||||||
if (constantOp) {
|
if (constantOp) {
|
||||||
DenseElementsAttr valueAttribute =
|
DenseElementsAttr valueAttribute =
|
||||||
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||||
|
|
||||||
if (!valueAttribute)
|
if (!valueAttribute)
|
||||||
return emitError("DenseElementsAttr expected");
|
return emitError("DenseElementsAttr expected");
|
||||||
// Get dims from valueAttribute.
|
// Get dims from valueAttribute.
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc";
|
const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc";
|
||||||
const std::string kCxxPath = "@CMAKE_CXX_COMPILER@";
|
const std::string kCxxPath = "@CMAKE_CXX_COMPILER@";
|
||||||
const std::string kCxxFileName = "@CXX_COMPILER_FILENAME@";
|
const std::string kLinkerPath = "@CMAKE_LINKER@";
|
||||||
const std::string kRuntimeDirPath = "@CMAKE_BINARY_DIR@/lib";
|
const std::string kObjCopyPath = "@CMAKE_OBJCOPY@";
|
||||||
}
|
} // namespace onnx_mlir
|
||||||
|
|
|
@ -9,8 +9,15 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
|
#include <regex>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <llvm/Support/FileSystem.h>
|
||||||
#include <llvm/Support/Program.h>
|
#include <llvm/Support/Program.h>
|
||||||
|
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||||
|
#include <mlir/IR/SymbolTable.h>
|
||||||
|
|
||||||
#include "src/ExternalUtil.hpp"
|
#include "src/ExternalUtil.hpp"
|
||||||
#include "src/MainUtils.hpp"
|
#include "src/MainUtils.hpp"
|
||||||
|
@ -26,6 +33,63 @@
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
llvm::Optional<std::string> getEnvVar(std::string name) {
|
||||||
|
if (const char *envVerbose = std::getenv(name.c_str()))
|
||||||
|
return std::string(envVerbose);
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper struct to make command construction and execution easy & readable.
|
||||||
|
struct Command {
|
||||||
|
std::string _path;
|
||||||
|
std::vector<std::string> _args;
|
||||||
|
|
||||||
|
Command(std::string exePath)
|
||||||
|
: _path(std::move(exePath)),
|
||||||
|
_args({llvm::sys::path::filename(_path).str()}) {}
|
||||||
|
|
||||||
|
// Append a single string argument.
|
||||||
|
Command &appendStr(const std::string &arg) {
|
||||||
|
_args.emplace_back(arg);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append a list of string arguments.
|
||||||
|
Command &appendList(const std::vector<std::string> &args) {
|
||||||
|
_args.insert(_args.end(), args.begin(), args.end());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset arguments.
|
||||||
|
Command &resetArgs() {
|
||||||
|
auto exeFileName = _args.front();
|
||||||
|
_args.clear();
|
||||||
|
_args.emplace_back(exeFileName);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute command.
|
||||||
|
void exec() {
|
||||||
|
auto argsRef = std::vector<llvm::StringRef>(_args.begin(), _args.end());
|
||||||
|
bool verbose = false;
|
||||||
|
if (const auto &verboseStr = getEnvVar("VERBOSE"))
|
||||||
|
istringstream(verboseStr.getValue()) >> verbose;
|
||||||
|
|
||||||
|
// If in verbose mode, print out command before execution.
|
||||||
|
if (verbose)
|
||||||
|
cout << llvm::join(argsRef, " ") << "\n";
|
||||||
|
int rc = llvm::sys::ExecuteAndWait(_path, llvm::makeArrayRef(argsRef));
|
||||||
|
|
||||||
|
if (rc != 0) {
|
||||||
|
fprintf(stderr, "%s\n", llvm::join(argsRef, " ").c_str());
|
||||||
|
llvm_unreachable("Command execution failed.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
||||||
mlir::OwningModuleRef &module) {
|
mlir::OwningModuleRef &module) {
|
||||||
// Handle '.mlir' input to the ONNX MLIR frontend.
|
// Handle '.mlir' input to the ONNX MLIR frontend.
|
||||||
|
@ -50,31 +114,123 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
||||||
|
|
||||||
void compileModuleToSharedLibrary(
|
void compileModuleToSharedLibrary(
|
||||||
const mlir::OwningModuleRef &module, string outputBaseName) {
|
const mlir::OwningModuleRef &module, string outputBaseName) {
|
||||||
|
// Extract constant pack file name, which is embedded as a symbol in the
|
||||||
|
// module being compiled.
|
||||||
|
auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>(
|
||||||
|
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName());
|
||||||
|
auto constPackFilePath = constPackFilePathSym.valueAttr()
|
||||||
|
.dyn_cast_or_null<mlir::StringAttr>()
|
||||||
|
.getValue()
|
||||||
|
.str();
|
||||||
|
llvm::FileRemover constPackRemover(constPackFilePath);
|
||||||
|
|
||||||
|
llvm::Optional<std::string> constPackObjPath;
|
||||||
|
#if __APPLE__
|
||||||
|
// Create a empty stub file, compile it to an empty obj file.
|
||||||
|
llvm::SmallVector<char, 20> stubSrcPath;
|
||||||
|
llvm::sys::fs::createTemporaryFile("stub", "cpp", stubSrcPath);
|
||||||
|
llvm::FileRemover subSrcRemover(stubSrcPath);
|
||||||
|
std::string stubSrcPathStr(stubSrcPath.begin(), stubSrcPath.end());
|
||||||
|
Command createStubObj(/*exePath=*/kCxxPath);
|
||||||
|
std::string stubObjPathStr = stubSrcPathStr + ".o";
|
||||||
|
createStubObj.appendList({"-o", stubObjPathStr})
|
||||||
|
.appendList({"-c", stubSrcPathStr})
|
||||||
|
.exec();
|
||||||
|
llvm::FileRemover stubObjRemover(stubObjPathStr);
|
||||||
|
|
||||||
|
// Embed data into the empty stub obj file.
|
||||||
|
constPackObjPath = constPackFilePath + ".o";
|
||||||
|
Command genParamObj(/*exePath=*/kLinkerPath);
|
||||||
|
genParamObj.appendStr("-r")
|
||||||
|
.appendList({"-o", constPackObjPath.getValue()})
|
||||||
|
.appendList({"-sectcreate", "binary", "param", constPackFilePath})
|
||||||
|
.appendStr(stubObjPathStr)
|
||||||
|
.exec();
|
||||||
|
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
||||||
|
|
||||||
|
#elif __linux__
|
||||||
|
// Create param.o holding packed parameter values.
|
||||||
|
constPackObjPath = constPackFilePath + ".o";
|
||||||
|
Command genParamObj(/*exePath=*/kLinkerPath);
|
||||||
|
genParamObj.appendStr("-r")
|
||||||
|
.appendList({"-b", "binary"})
|
||||||
|
.appendList({"-o", constPackObjPath.getValue()})
|
||||||
|
.appendStr(constPackFilePath)
|
||||||
|
.exec();
|
||||||
|
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
||||||
|
|
||||||
|
// Figure out what is the default symbol name describing the start/end
|
||||||
|
// address of the embedded data.
|
||||||
|
std::regex e("[^0-9A-Za-z]");
|
||||||
|
auto sanitizedName =
|
||||||
|
"_binary_" + std::regex_replace(constPackFilePath, e, "_");
|
||||||
|
|
||||||
|
// Rename the symbols to saner ones expected by the runtime function.
|
||||||
|
Command redefineSym(/*exePath=*/kObjCopyPath);
|
||||||
|
redefineSym.appendStr("--redefine-sym")
|
||||||
|
.appendStr(sanitizedName + "_start=_binary_param_bin_start")
|
||||||
|
.appendStr(constPackObjPath.getValue())
|
||||||
|
.exec();
|
||||||
|
redefineSym.resetArgs()
|
||||||
|
.appendStr("--redefine-sym")
|
||||||
|
.appendStr(sanitizedName + "_end=_binary_param_bin_end")
|
||||||
|
.appendStr(constPackObjPath.getValue())
|
||||||
|
.exec();
|
||||||
|
|
||||||
|
#else
|
||||||
|
llvm::SmallVector<char, 10> permConstPackFileName(
|
||||||
|
constPackFilePath.begin(), constPackFilePath.end());
|
||||||
|
llvm::sys::path::replace_extension(permConstPackFileName, "bin");
|
||||||
|
std::string permConstPackFileNameStr(
|
||||||
|
permConstPackFileName.begin(), permConstPackFileName.end());
|
||||||
|
auto constPackFileName = llvm::sys::path::filename(outputBaseName) + "." +
|
||||||
|
llvm::sys::path::filename(permConstPackFileNameStr);
|
||||||
|
llvm::sys::fs::rename(constPackFilePath, constPackFileName);
|
||||||
|
|
||||||
|
mlir::Builder builder(*module);
|
||||||
|
(*module)
|
||||||
|
.lookupSymbol<mlir::LLVM::GlobalOp>(
|
||||||
|
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName())
|
||||||
|
.valueAttr(builder.getStringAttr(constPackFileName.str()));
|
||||||
|
(*module)
|
||||||
|
.lookupSymbol<mlir::LLVM::GlobalOp>(
|
||||||
|
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName())
|
||||||
|
.valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size()));
|
||||||
|
#endif
|
||||||
|
|
||||||
// Write LLVM bitcode.
|
// Write LLVM bitcode.
|
||||||
string outputFilename = outputBaseName + ".bc";
|
string outputFilename = outputBaseName + ".bc";
|
||||||
error_code error;
|
error_code error;
|
||||||
llvm::raw_fd_ostream moduleBitcodeStream(
|
llvm::raw_fd_ostream moduleBitcodeStream(
|
||||||
outputFilename, error, llvm::sys::fs::F_None);
|
outputFilename, error, llvm::sys::fs::F_None);
|
||||||
|
|
||||||
llvm::WriteBitcodeToFile(
|
llvm::WriteBitcodeToFile(
|
||||||
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
||||||
moduleBitcodeStream.flush();
|
moduleBitcodeStream.flush();
|
||||||
|
llvm::FileRemover bcRemover(outputFilename);
|
||||||
|
|
||||||
// Compile bitcode to object file.
|
// Compile LLVM bitcode to object file.
|
||||||
std::vector<std::string> llcArgs = {
|
Command llvmToObj(/*exePath=*/kLlcPath);
|
||||||
"llc", "-filetype=obj", "-relocation-model=pic", outputFilename};
|
llvmToObj.appendStr("-filetype=obj");
|
||||||
auto llcArgStrRefs =
|
llvmToObj.appendStr("-relocation-model=pic");
|
||||||
std::vector<llvm::StringRef>(llcArgs.begin(), llcArgs.end());
|
llvmToObj.appendStr(outputFilename);
|
||||||
llvm::sys::ExecuteAndWait(kLlcPath, llvm::makeArrayRef(llcArgStrRefs));
|
llvmToObj.exec();
|
||||||
|
std::string modelObjPath = outputBaseName + ".o";
|
||||||
|
llvm::FileRemover modelObjRemover(modelObjPath);
|
||||||
|
|
||||||
// Link with runtime.
|
std::string runtimeDirInclFlag;
|
||||||
// TODO(tjingrant): link with runtime library in LLVM, and make the shared
|
if (getEnvVar("RUNTIME_DIR").hasValue())
|
||||||
// library more self-contained.
|
runtimeDirInclFlag = "-L" + getEnvVar("RUNTIME_DIR").getValue();
|
||||||
std::vector<std::string> cxxArgs = {kCxxFileName, "-shared", "-fPIC",
|
|
||||||
outputBaseName + ".o", "-o", outputBaseName + ".so",
|
// Link everything into a shared object.
|
||||||
"-L" + kRuntimeDirPath, "-lcruntime", "-Wl,-rpath," + kRuntimeDirPath};
|
Command link(kCxxPath);
|
||||||
auto argsArrayRefVector =
|
link.appendList({"-shared", "-fPIC"})
|
||||||
std::vector<llvm::StringRef>(cxxArgs.begin(), cxxArgs.end());
|
.appendStr(modelObjPath)
|
||||||
llvm::sys::ExecuteAndWait(kCxxPath, llvm::makeArrayRef(argsArrayRefVector));
|
.appendStr(constPackObjPath.getValueOr(""))
|
||||||
|
.appendList({"-o", outputBaseName + ".so"})
|
||||||
|
.appendStr(runtimeDirInclFlag)
|
||||||
|
.appendList({"-lEmbeddedDataLoader", "-lcruntime"})
|
||||||
|
.exec();
|
||||||
}
|
}
|
||||||
|
|
||||||
void registerDialects() {
|
void registerDialects() {
|
||||||
|
@ -98,6 +254,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
|
||||||
|
|
||||||
void addONNXToKrnlPasses(mlir::PassManager &pm) {
|
void addONNXToKrnlPasses(mlir::PassManager &pm) {
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
pm.addPass(mlir::createPackKrnlGlobalConstantsPass());
|
||||||
// An additional pass of canonicalization is helpful because lowering
|
// An additional pass of canonicalization is helpful because lowering
|
||||||
// from ONNX dialect to Standard dialect exposes additional canonicalization
|
// from ONNX dialect to Standard dialect exposes additional canonicalization
|
||||||
// oppertunities.
|
// oppertunities.
|
||||||
|
@ -110,7 +267,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
|
||||||
void addKrnlToAffinePasses(mlir::PassManager &pm) {
|
void addKrnlToAffinePasses(mlir::PassManager &pm) {
|
||||||
pm.addPass(mlir::createLowerKrnlPass());
|
pm.addPass(mlir::createLowerKrnlPass());
|
||||||
// Fuse loops in Affine dialect.
|
// Fuse loops in Affine dialect.
|
||||||
pm.addPass(mlir::createLoopFusionPass());
|
// pm.addPass(mlir::createLoopFusionPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
void addKrnlToLLVMPasses(mlir::PassManager &pm) {
|
void addKrnlToLLVMPasses(mlir::PassManager &pm) {
|
||||||
|
|
|
@ -43,4 +43,7 @@ std::unique_ptr<Pass> createElideConstGlobalValuePass();
|
||||||
/// Pass for lowering Krnl dialect to LLVM dialect.
|
/// Pass for lowering Krnl dialect to LLVM dialect.
|
||||||
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
|
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
|
||||||
|
|
||||||
|
/// Pass for packing Krnl global constants.
|
||||||
|
std::unique_ptr<Pass> createPackKrnlGlobalConstantsPass();
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Create shared libcruntime.so since model.so linkage for backend tests
|
# Create shared libcruntime.so since model.so linkage for backend tests
|
||||||
# will fail on x86 Linux if cruntime is statically linked.
|
# will fail on x86 Linux if cruntime is statically linked.
|
||||||
add_library(cruntime SHARED
|
add_library(cruntime STATIC
|
||||||
DynMemRef.cpp
|
DynMemRef.cpp
|
||||||
DynMemRef.h
|
DynMemRef.h
|
||||||
DataType.h)
|
DataType.h)
|
||||||
|
@ -34,6 +34,13 @@ target_include_directories(PyRuntime PRIVATE
|
||||||
${ONNX_MLIR_BIN_ROOT}
|
${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNX_MLIR_SRC_ROOT})
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
|
|
||||||
|
add_library(EmbeddedDataLoader STATIC
|
||||||
|
GetEmbeddedConstPool.h
|
||||||
|
GetEmbeddedConstPool.cpp)
|
||||||
|
set_target_properties(EmbeddedDataLoader PROPERTIES
|
||||||
|
POSITION_INDEPENDENT_CODE TRUE)
|
||||||
|
|
||||||
add_dependencies(PyRuntime cruntime)
|
add_dependencies(PyRuntime cruntime)
|
||||||
install(FILES DynMemRef.h DESTINATION include)
|
install(FILES DynMemRef.h DESTINATION include)
|
||||||
install(TARGETS cruntime DESTINATION lib)
|
install(TARGETS cruntime DESTINATION lib)
|
||||||
|
install(TARGETS EmbeddedDataLoader DESTINATION lib)
|
||||||
|
|
|
@ -14,6 +14,10 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
#include "DynMemRef.h"
|
#include "DynMemRef.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
//===---- GetEmbeddedConstPool.h - Get Embedded Const Pool API Func Impl---===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains runtime API implementation to extract constant pool values
|
||||||
|
// embedded within the shared library binary files.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "GetEmbeddedConstPool.h"
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
// Adapted from:
|
||||||
|
// https://developer.ibm.com/technologies/systems/articles/au-endianc/
|
||||||
|
const int i = 1;
|
||||||
|
#define IS_SYSTEM_LE() (!((*(char *)&i) == 0))
|
||||||
|
|
||||||
|
#define XOR(a, b) (!(a) != !(b))
|
||||||
|
|
||||||
|
extern const char constPackIsLE;
|
||||||
|
|
||||||
|
void checkEndianness() {
|
||||||
|
if (XOR(IS_SYSTEM_LE(), constPackIsLE)) {
|
||||||
|
fprintf(stderr, "Constant pack is stored in a byte order that is not "
|
||||||
|
"native to this current system.");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __APPLE__
|
||||||
|
#include <mach-o/getsect.h>
|
||||||
|
extern const struct mach_header_64 _mh_dylib_header;
|
||||||
|
|
||||||
|
void *getEmbeddedConstPool(int64_t size_in_byte) {
|
||||||
|
checkEndianness();
|
||||||
|
size_t size = size_in_byte;
|
||||||
|
unsigned char *data =
|
||||||
|
getsectiondata(&_mh_dylib_header, "binary", "param", &size);
|
||||||
|
float *data_ptr = (float *)data;
|
||||||
|
void *buffer = malloc(size);
|
||||||
|
memcpy(buffer, data, size);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
#elif __linux__
|
||||||
|
extern char _binary_param_bin_start;
|
||||||
|
extern char _binary_param_bin_end;
|
||||||
|
|
||||||
|
void *getEmbeddedConstPool(int64_t _) {
|
||||||
|
checkEndianness();
|
||||||
|
auto size = (unsigned int)(&_binary_param_bin_end - &_binary_param_bin_start);
|
||||||
|
void *buffer = malloc(size);
|
||||||
|
memcpy(buffer, &_binary_param_bin_start, size);
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
extern char constPackFileName[];
|
||||||
|
extern int64_t constPackFileNameStrLen;
|
||||||
|
|
||||||
|
void *getEmbeddedConstPool(int64_t _) {
|
||||||
|
checkEndianness();
|
||||||
|
char *fname = (char *)calloc(1, constPackFileNameStrLen + 1);
|
||||||
|
memcpy(fname, constPackFileName, constPackFileNameStrLen);
|
||||||
|
|
||||||
|
// Adapted from https://stackoverflow.com/a/22059317 .
|
||||||
|
FILE *fileptr;
|
||||||
|
char *buffer;
|
||||||
|
long filelen;
|
||||||
|
|
||||||
|
fileptr = fopen(fname, "rb"); // Open the file in binary mode
|
||||||
|
fseek(fileptr, 0, SEEK_END); // Jump to the end of the file
|
||||||
|
filelen = ftell(fileptr); // Get the current byte offset in the file
|
||||||
|
rewind(fileptr); // Jump back to the beginning of the file
|
||||||
|
|
||||||
|
buffer = (char *)malloc(filelen * sizeof(char)); // Enough memory for the file
|
||||||
|
fread(buffer, filelen, 1, fileptr); // Read in the entire file
|
||||||
|
fclose(fileptr); // Close the file
|
||||||
|
|
||||||
|
return (void *)buffer;
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -0,0 +1,18 @@
|
||||||
|
//===---- GetEmbeddedConstPool.h - Get Embedded Const Pool API Func Decl---===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains runtime API declarations to extract constant pool values
|
||||||
|
// embedded within the shared library binary files.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
void *getEmbeddedConstPool(int64_t size_in_byte);
|
||||||
|
}
|
|
@ -0,0 +1,95 @@
|
||||||
|
//===----- BinaryDecoder.cpp - Decode binary files into typed arrays ------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains implementation of a utility called BinaryDecoder, which
|
||||||
|
// decodes a sequence of binary data within a binary file specified by an
|
||||||
|
// offset and a length into a typed array and print to stdout.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "onnx/onnx_pb.h"
|
||||||
|
#include <llvm/Support/CommandLine.h>
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
typedef uint8_t u_int8_t;
|
||||||
|
typedef uint16_t u_int16_t;
|
||||||
|
typedef uint32_t u_int32_t;
|
||||||
|
typedef uint64_t u_int64_t;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
llvm::cl::opt<std::string> Filename(
|
||||||
|
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::Required);
|
||||||
|
llvm::cl::opt<int64_t> Start("s",
|
||||||
|
llvm::cl::desc("Specify the index of the starting byte"),
|
||||||
|
llvm::cl::value_desc("start"), llvm::cl::Required);
|
||||||
|
llvm::cl::opt<int64_t> Size("n",
|
||||||
|
llvm::cl::desc("Specify the number of bytes of data to decode"),
|
||||||
|
llvm::cl::value_desc("size"), llvm::cl::Required);
|
||||||
|
llvm::cl::opt<bool> Remove(
|
||||||
|
"rm", llvm::cl::desc(
|
||||||
|
"Whether to remove the file being decoded after inspection."));
|
||||||
|
|
||||||
|
llvm::cl::opt<onnx::TensorProto::DataType> DataType(
|
||||||
|
llvm::cl::desc("Choose data type to decode:"),
|
||||||
|
llvm::cl::values(clEnumVal(onnx::TensorProto::FLOAT, "FLOAT"),
|
||||||
|
clEnumVal(onnx::TensorProto::UINT8, "UINT8"),
|
||||||
|
clEnumVal(onnx::TensorProto::INT8, "INT8"),
|
||||||
|
clEnumVal(onnx::TensorProto::UINT16, "UINT16"),
|
||||||
|
clEnumVal(onnx::TensorProto::INT16, "INT16"),
|
||||||
|
clEnumVal(onnx::TensorProto::INT32, "INT32"),
|
||||||
|
clEnumVal(onnx::TensorProto::INT64, "INT64"),
|
||||||
|
clEnumVal(onnx::TensorProto::STRING, "STRING"),
|
||||||
|
clEnumVal(onnx::TensorProto::BOOL, "BOOL"),
|
||||||
|
clEnumVal(onnx::TensorProto::FLOAT16, "FLOAT16"),
|
||||||
|
clEnumVal(onnx::TensorProto::DOUBLE, "DOUBLE"),
|
||||||
|
clEnumVal(onnx::TensorProto::UINT32, "UINT32"),
|
||||||
|
clEnumVal(onnx::TensorProto::UINT64, "UINT64")));
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
int printBuffer(std::vector<char> buffer) {
|
||||||
|
auto *ptr = (T *)&buffer[0];
|
||||||
|
auto data = std::vector<T>(ptr, ptr + buffer.size() / sizeof(T));
|
||||||
|
for (const auto &elem : data)
|
||||||
|
std::cout << elem << " ";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||||
|
std::vector<char> buffer(Size);
|
||||||
|
std::ifstream file(Filename, std::ios::in | std::ios::binary);
|
||||||
|
if (!file)
|
||||||
|
return -1;
|
||||||
|
file.seekg(Start, file.beg);
|
||||||
|
file.read(&buffer[0], Size);
|
||||||
|
file.close();
|
||||||
|
|
||||||
|
if (Remove)
|
||||||
|
llvm::sys::fs::remove(Filename);
|
||||||
|
|
||||||
|
#define PRINT_BUFFER_FOR_TYPE(ONNX_TYPE, CPP_TYPE) \
|
||||||
|
if (DataType == ONNX_TYPE) \
|
||||||
|
return printBuffer<CPP_TYPE>(buffer);
|
||||||
|
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT8, u_int8_t);
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT16, u_int16_t);
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT16, int16_t);
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT32, int32_t);
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT64, int64_t);
|
||||||
|
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::FLOAT, float);
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::DOUBLE, double);
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT32, u_int32_t);
|
||||||
|
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT64, u_int64_t);
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
add_executable(binary-decoder BinaryDecoder.cpp)
|
||||||
|
target_link_libraries(binary-decoder
|
||||||
|
${LLVMSupport}
|
||||||
|
${LLVMDemangle}
|
||||||
|
${CURSES_LIBRARIES}
|
||||||
|
${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
|
||||||
|
message(STATUS "incl dir" ${ONNX_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(binary-decoder onnx)
|
|
@ -1 +1,2 @@
|
||||||
add_subdirectory(ONNXMLIROpt)
|
add_subdirectory(ONNXMLIROpt)
|
||||||
|
add_subdirectory(BinaryDecoder)
|
|
@ -27,7 +27,8 @@ target_link_libraries(OMKrnlToLLVM
|
||||||
onnx)
|
onnx)
|
||||||
|
|
||||||
add_library(OMElideKrnlGlobalConstants
|
add_library(OMElideKrnlGlobalConstants
|
||||||
ElideKrnlGlobalConstants.cpp)
|
ElideKrnlGlobalConstants.cpp
|
||||||
|
ElideKrnlGlobalConstants.hpp)
|
||||||
target_include_directories(OMElideKrnlGlobalConstants
|
target_include_directories(OMElideKrnlGlobalConstants
|
||||||
PRIVATE
|
PRIVATE
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
|
@ -38,6 +39,14 @@ add_dependencies(OMElideKrnlGlobalConstants OMKrnlOpsInc)
|
||||||
# Linking dependencies
|
# Linking dependencies
|
||||||
add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps)
|
add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps)
|
||||||
|
|
||||||
|
add_library(OMPackKrnlGlobalConstants
|
||||||
|
PackKrnlGlobalConstants.cpp)
|
||||||
|
target_include_directories(OMPackKrnlGlobalConstants
|
||||||
|
PRIVATE
|
||||||
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
|
${ONNX_MLIR_BIN_ROOT}
|
||||||
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
|
|
||||||
add_library(OMEnableMemoryPool
|
add_library(OMEnableMemoryPool
|
||||||
EnableMemoryPool.cpp)
|
EnableMemoryPool.cpp)
|
||||||
target_include_directories(OMEnableMemoryPool
|
target_include_directories(OMEnableMemoryPool
|
||||||
|
@ -45,6 +54,9 @@ target_include_directories(OMEnableMemoryPool
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
${ONNX_MLIR_BIN_ROOT}
|
${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNX_MLIR_SRC_ROOT})
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
|
add_dependencies(OMPackKrnlGlobalConstants
|
||||||
|
OMKrnlOps)
|
||||||
|
|
||||||
target_link_libraries(OMEnableMemoryPool
|
target_link_libraries(OMEnableMemoryPool
|
||||||
onnx)
|
onnx)
|
||||||
add_dependencies(OMEnableMemoryPool
|
add_dependencies(OMEnableMemoryPool
|
||||||
|
|
|
@ -22,34 +22,31 @@
|
||||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||||
#include "src/Pass/Passes.hpp"
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
#include "ElideKrnlGlobalConstants.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
const int64_t KrnlConstGlobalValueElision::kDefaultElisionThreshold = 32;
|
||||||
|
|
||||||
/*!
|
mlir::LogicalResult KrnlConstGlobalValueElision::matchAndRewrite(
|
||||||
* RewritePattern that replaces existing constant Krnl global values
|
mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const {
|
||||||
* with a similar operation which preserves all attributes except the value
|
auto loc = op.getLoc();
|
||||||
* attribute.
|
|
||||||
*/
|
|
||||||
|
|
||||||
class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
|
if (op.value().hasValue()) {
|
||||||
public:
|
const auto &valAttr = op.valueAttr().dyn_cast_or_null<DenseElementsAttr>();
|
||||||
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
|
if (valAttr.getNumElements() > elisionThreshold) {
|
||||||
|
IntegerAttr offsetAttr = op.offset() ? op.offsetAttr() : nullptr;
|
||||||
LogicalResult matchAndRewrite(
|
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(loc,
|
||||||
KrnlGlobalOp op, PatternRewriter &rewriter) const override {
|
op.getResult().getType(), /*shape=*/op.shape(),
|
||||||
auto loc = op.getLoc();
|
/*name=*/op.name(), /*value=*/nullptr, /*offset=*/offsetAttr);
|
||||||
|
|
||||||
if (op.value().hasValue()) {
|
|
||||||
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(
|
|
||||||
loc, op.getResult().getType(), op.shape(), op.name(), nullptr);
|
|
||||||
rewriter.replaceOp(op, newGlobalOp.getResult());
|
rewriter.replaceOp(op, newGlobalOp.getResult());
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
/*!
|
/*!
|
||||||
* Function pass that performs constant value elision of Krnl globals.
|
* Function pass that performs constant value elision of Krnl globals.
|
||||||
*/
|
*/
|
||||||
|
@ -66,6 +63,7 @@ public:
|
||||||
applyPatternsAndFoldGreedily(function, patterns);
|
applyPatternsAndFoldGreedily(function, patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createElideConstGlobalValuePass() {
|
std::unique_ptr<Pass> mlir::createElideConstGlobalValuePass() {
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* RewritePattern that replaces existing constant Krnl global values
|
||||||
|
* with a similar operation which preserves all attributes except the value
|
||||||
|
* attribute.
|
||||||
|
*/
|
||||||
|
class KrnlConstGlobalValueElision
|
||||||
|
: public mlir::OpRewritePattern<mlir::KrnlGlobalOp> {
|
||||||
|
public:
|
||||||
|
/*
|
||||||
|
* A threshold value specifying the maximum number of elements a constant
|
||||||
|
* operation can hold as an attribute. If the number exceeds this threshold,
|
||||||
|
* constants will be packed together and, in the case where `move-to-file`
|
||||||
|
* option is enabled, stored as a binary file on disk. This can help preserve
|
||||||
|
* readability of IR dump and improve compilation speed.
|
||||||
|
*/
|
||||||
|
static const int64_t kDefaultElisionThreshold;
|
||||||
|
|
||||||
|
int64_t elisionThreshold;
|
||||||
|
|
||||||
|
using mlir::OpRewritePattern<mlir::KrnlGlobalOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
explicit KrnlConstGlobalValueElision(
|
||||||
|
mlir::MLIRContext *context, int64_t elisionThreshold)
|
||||||
|
: OpRewritePattern(context), elisionThreshold(elisionThreshold) {}
|
||||||
|
|
||||||
|
mlir::LogicalResult matchAndRewrite(
|
||||||
|
mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const override;
|
||||||
|
};
|
|
@ -32,10 +32,9 @@ namespace {
|
||||||
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
||||||
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
|
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
|
||||||
auto *context = module.getContext();
|
auto *context = module.getContext();
|
||||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
|
if (auto sym = module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
|
||||||
auto symbolRef = SymbolRefAttr::get(funcName, context);
|
assert(sym.getType() == funcType && "wrong symbol type");
|
||||||
assert(symbolRef.getType() == funcType && "wrong symbol type");
|
return SymbolRefAttr::get(funcName, context);
|
||||||
return symbolRef;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert the function into the body of the parent module.
|
// Insert the function into the body of the parent module.
|
||||||
|
@ -203,55 +202,80 @@ public:
|
||||||
// The llvm type of the global (example: [2 x [8 x float]])
|
// The llvm type of the global (example: [2 x [8 x float]])
|
||||||
auto llvmGlobalType = globalType.cast<LLVM::LLVMType>();
|
auto llvmGlobalType = globalType.cast<LLVM::LLVMType>();
|
||||||
|
|
||||||
{
|
mlir::Value alloc;
|
||||||
OpBuilder::InsertionGuard insertGuard(rewriter);
|
if (krnlGlobalOp.value().hasValue()) {
|
||||||
rewriter.setInsertionPointToStart(module.getBody());
|
{
|
||||||
|
OpBuilder::InsertionGuard insertGuard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
|
||||||
assert(krnlGlobalOp.value().hasValue() &&
|
assert(krnlGlobalOp.value().hasValue() &&
|
||||||
"Krnl Global must always have a value");
|
"Krnl Global must always have a value");
|
||||||
global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
|
global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
|
||||||
/*isConstant=*/true, LLVM::Linkage::Internal, name,
|
/*isConstant=*/true, LLVM::Linkage::Internal, name,
|
||||||
krnlGlobalOp.value().getValue());
|
krnlGlobalOp.value().getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some frequently used types.
|
||||||
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
|
||||||
|
// Allocate the memory where the constants will be used from.
|
||||||
|
// This is a region of local memory and needs to be emitted as an alloca.
|
||||||
|
auto one = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
|
||||||
|
alloc = rewriter.create<LLVM::AllocaOp>(
|
||||||
|
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
|
||||||
|
|
||||||
|
// Copy constant value into the local alloca:
|
||||||
|
// - Bitcast alloc to i8*
|
||||||
|
Value int8PtrAlloc =
|
||||||
|
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, alloc);
|
||||||
|
// - Bitcast global to i8*
|
||||||
|
Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
||||||
|
Value i8PtrGlobal =
|
||||||
|
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue);
|
||||||
|
// - Set size.
|
||||||
|
Value memRefElementSize =
|
||||||
|
rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
|
||||||
|
rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy)));
|
||||||
|
Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
|
||||||
|
Value totalElementsSize = rewriter.create<LLVM::MulOp>(
|
||||||
|
loc, memRefElementSize, numElementsValue);
|
||||||
|
Value int64Size =
|
||||||
|
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
|
||||||
|
// - Set volatile.
|
||||||
|
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
||||||
|
LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||||
|
// - Copy constant data into the alloca.
|
||||||
|
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
|
||||||
|
rewriter.create<CallOp>(loc, memcpyRef,
|
||||||
|
LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||||
|
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
||||||
|
} else {
|
||||||
|
// Some frequently used types.
|
||||||
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
|
||||||
|
// Allocate the memory where the constants will be used from.
|
||||||
|
// This is a region of local memory and needs to be emitted as an alloca.
|
||||||
|
auto one = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
|
||||||
|
|
||||||
|
auto base = module.lookupSymbol<LLVM::GlobalOp>("packedConst");
|
||||||
|
assert(base && "Cannot find symbol packedConst.");
|
||||||
|
|
||||||
|
Value constPackBasePtrAddr =
|
||||||
|
rewriter.create<LLVM::AddressOfOp>(loc, base);
|
||||||
|
Value constPackBasePtr = rewriter.create<LLVM::LoadOp>(
|
||||||
|
loc, base.getType(), constPackBasePtrAddr);
|
||||||
|
auto offset = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
|
||||||
|
rewriter.getI64IntegerAttr(
|
||||||
|
krnlGlobalOp.offsetAttr().getValue().getSExtValue()));
|
||||||
|
alloc = rewriter.create<LLVM::GEPOp>(
|
||||||
|
loc, llvmI8PtrTy, constPackBasePtr, ValueRange({offset}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Some frequently used types.
|
|
||||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
|
||||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
|
||||||
|
|
||||||
// Allocate the memory where the constants will be used from.
|
|
||||||
// This is a region of local memory and needs to be emitted as an alloca.
|
|
||||||
auto one = rewriter.create<LLVM::ConstantOp>(
|
|
||||||
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
|
|
||||||
auto alloc = rewriter.create<LLVM::AllocaOp>(
|
|
||||||
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
|
|
||||||
|
|
||||||
// Copy constant value into the local alloca:
|
|
||||||
// - Bitcast alloc to i8*
|
|
||||||
Value int8PtrAlloc =
|
|
||||||
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, alloc);
|
|
||||||
// - Bitcast global to i8*
|
|
||||||
Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
|
||||||
Value i8PtrGlobal =
|
|
||||||
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue);
|
|
||||||
// - Set size.
|
|
||||||
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
|
|
||||||
rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy)));
|
|
||||||
Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
|
|
||||||
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
|
|
||||||
Value totalElementsSize =
|
|
||||||
rewriter.create<LLVM::MulOp>(loc, memRefElementSize, numElementsValue);
|
|
||||||
Value int64Size =
|
|
||||||
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
|
|
||||||
// - Set volatile.
|
|
||||||
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
|
||||||
LLVM::LLVMType::getInt1Ty(llvmDialect),
|
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
|
||||||
// - Copy constant data into the alloca.
|
|
||||||
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
|
|
||||||
rewriter.create<CallOp>(loc, memcpyRef,
|
|
||||||
LLVM::LLVMType::getVoidTy(llvmDialect),
|
|
||||||
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
|
||||||
|
|
||||||
// Prepare data to be inserted into MemRef.
|
// Prepare data to be inserted into MemRef.
|
||||||
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
|
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
|
||||||
Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
|
Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
@ -681,6 +705,115 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KRNL to LLVM: KrnlPackedConstOpLowering
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class KrnlPackedConstOpLowering : public ConvertToLLVMPattern {
|
||||||
|
public:
|
||||||
|
explicit KrnlPackedConstOpLowering(
|
||||||
|
MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||||
|
: ConvertToLLVMPattern(
|
||||||
|
KrnlPackedConstantOp::getOperationName(), context, lowering_) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto *context = op->getContext();
|
||||||
|
ModuleOp module = op->getParentOfType<ModuleOp>();
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
|
auto *llvmDialect =
|
||||||
|
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||||
|
|
||||||
|
auto packedConstOp = llvm::dyn_cast<KrnlPackedConstantOp>(op);
|
||||||
|
LLVM::GlobalOp globalBase;
|
||||||
|
// Some frequently used types.
|
||||||
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard insertGuard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
|
||||||
|
globalBase = rewriter.create<LLVM::GlobalOp>(loc, llvmI8PtrTy,
|
||||||
|
/*isConstant=*/false, LLVM::Linkage::Internal, "packedConst",
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto mainFunc = module.lookupSymbol<FuncOp>("main_graph");
|
||||||
|
assert(mainFunc);
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(
|
||||||
|
&mainFunc.getBody().front(), mainFunc.getBody().front().begin());
|
||||||
|
|
||||||
|
// - Initialize the global constant base.
|
||||||
|
Value basePtrAddr = rewriter.create<LLVM::AddressOfOp>(loc, globalBase);
|
||||||
|
auto getEmbeddedConstPoolRef = getOrInsertExternFunc(
|
||||||
|
KrnlPackedConstantOp::getEmbeddedDataLoaderMethodName(), module,
|
||||||
|
LLVM::LLVMType::getFunctionTy(
|
||||||
|
llvmI8PtrTy, {llvmI64Ty}, /*isVarArg=*/false),
|
||||||
|
rewriter);
|
||||||
|
auto constPackSize = rewriter.create<LLVM::ConstantOp>(loc,
|
||||||
|
LLVM::LLVMType::getInt64Ty(llvmDialect),
|
||||||
|
packedConstOp.size_in_bytesAttr());
|
||||||
|
Value alloc = rewriter
|
||||||
|
.create<CallOp>(loc, getEmbeddedConstPoolRef, llvmI8PtrTy,
|
||||||
|
ArrayRef<Value>({constPackSize}))
|
||||||
|
.getResult(0);
|
||||||
|
rewriter.create<LLVM::StoreOp>(loc, alloc, basePtrAddr);
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard insertGuard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
|
||||||
|
// Record constant pack *file path* as a global variable (by recording the
|
||||||
|
// file path string's underlying char array + its length).
|
||||||
|
const auto &fileNameAttr = packedConstOp.file_nameAttr();
|
||||||
|
auto type =
|
||||||
|
LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
|
||||||
|
fileNameAttr.getValue().size());
|
||||||
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::External,
|
||||||
|
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName(),
|
||||||
|
fileNameAttr);
|
||||||
|
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::External,
|
||||||
|
mlir::KrnlPackedConstantOp::getConstPackFilePathStrLenSymbolName(),
|
||||||
|
rewriter.getI64IntegerAttr(fileNameAttr.getValue().size()));
|
||||||
|
|
||||||
|
// Record constant pack *file name* as a global variable (by recording the
|
||||||
|
// file name string's underlying char array + its length).
|
||||||
|
auto constPackFileName =
|
||||||
|
llvm::sys::path::filename(fileNameAttr.getValue());
|
||||||
|
type = LLVM::LLVMType::getArrayTy(
|
||||||
|
LLVM::LLVMType::getInt8Ty(llvmDialect), constPackFileName.size());
|
||||||
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::External,
|
||||||
|
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName(),
|
||||||
|
rewriter.getStringAttr(constPackFileName));
|
||||||
|
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::External,
|
||||||
|
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName(),
|
||||||
|
rewriter.getI64IntegerAttr(constPackFileName.size()));
|
||||||
|
|
||||||
|
type = LLVM::LLVMType::getInt8Ty(llvmDialect);
|
||||||
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::External,
|
||||||
|
mlir::KrnlPackedConstantOp::getConstPackIsLESymbolName(),
|
||||||
|
rewriter.getI8IntegerAttr(packedConstOp.is_le()));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
|
||||||
|
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||||
|
}
|
||||||
|
};
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -712,7 +845,8 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
|
||||||
/*emitCWrapperS=*/true,
|
/*emitCWrapperS=*/true,
|
||||||
/*useAlignedAlloc=*/false);
|
/*useAlignedAlloc=*/false);
|
||||||
|
|
||||||
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
|
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
|
||||||
|
&getContext(), typeConverter);
|
||||||
patterns.insert<KrnlGetRefOpLowering>(&getContext(), typeConverter);
|
patterns.insert<KrnlGetRefOpLowering>(&getContext(), typeConverter);
|
||||||
|
|
||||||
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
||||||
|
@ -722,8 +856,9 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
|
||||||
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
||||||
// ensures that only legal operations will remain after the conversion.
|
// ensures that only legal operations will remain after the conversion.
|
||||||
if (failed(applyFullConversion(
|
if (failed(applyFullConversion(
|
||||||
getOperation(), target, patterns, &typeConverter)))
|
getOperation(), target, patterns, &typeConverter))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
|
/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
|
||||||
|
|
|
@ -50,8 +50,9 @@ public:
|
||||||
|
|
||||||
int64_t dynamicOperations = 0;
|
int64_t dynamicOperations = 0;
|
||||||
f.walk([&](mlir::Operation *op) {
|
f.walk([&](mlir::Operation *op) {
|
||||||
if (returnsDynamicShape(op))
|
if (returnsDynamicShape(op)) {
|
||||||
dynamicOperations++;
|
dynamicOperations++;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// If any dynamic operations remain, this indicates a failure.
|
// If any dynamic operations remain, this indicates a failure.
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
//===- ElideKrnlGlobalConstants.cpp - Krnl Constant lobal Value Elision ---===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// In practice, the constant values of Global Krnl operations may be large
|
||||||
|
// enough to hinder the readability of the MLIR intermediate representation.
|
||||||
|
//
|
||||||
|
// This file creates a pass which elides the explicit values of constant
|
||||||
|
// global operations. This pass has purely cosmetic purposes and should only be
|
||||||
|
// run to obtain a compact representation of the program when emitting Krnl
|
||||||
|
// dialect code. This pass should never be invoked on code meant to be run.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "llvm/Support/FileSystem.h"
|
||||||
|
|
||||||
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||||
|
#include "src/Pass/Passes.hpp"
|
||||||
|
#include "src/Transform/ElideKrnlGlobalConstants.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Function pass that performs constant value elision of Krnl globals.
|
||||||
|
*/
|
||||||
|
class PackKrnlGlobalConstantsPass
|
||||||
|
: public PassWrapper<PackKrnlGlobalConstantsPass, OperationPass<ModuleOp>> {
|
||||||
|
public:
|
||||||
|
/// Make sure that we have a valid default constructor and copy constructor to
|
||||||
|
/// make sure that the options are initialized properly.
|
||||||
|
PackKrnlGlobalConstantsPass() = default;
|
||||||
|
PackKrnlGlobalConstantsPass(const PackKrnlGlobalConstantsPass &pass) {}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto module = getOperation();
|
||||||
|
OpBuilder builder(&getContext());
|
||||||
|
|
||||||
|
// Packing constant arrays to packedConst.
|
||||||
|
std::vector<char> packedConst;
|
||||||
|
module.walk([&](KrnlGlobalOp op) {
|
||||||
|
assert(op.value());
|
||||||
|
op.offsetAttr(builder.getI64IntegerAttr(packedConst.size()));
|
||||||
|
assert(op.value()->isa<DenseElementsAttr>());
|
||||||
|
const auto &denseAttr = op.valueAttr().cast<DenseElementsAttr>();
|
||||||
|
auto numElements = denseAttr.getNumElements();
|
||||||
|
if (numElements <= elisionThreshold)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// TODO(tjingrant) verify we can actually use the raw data.
|
||||||
|
std::vector<char> rawData = denseAttr.getRawData();
|
||||||
|
packedConst.insert(packedConst.end(), rawData.begin(), rawData.end());
|
||||||
|
});
|
||||||
|
|
||||||
|
// Remove value attributes from krnl constant op.
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<KrnlConstGlobalValueElision>(
|
||||||
|
&getContext(), elisionThreshold);
|
||||||
|
// Apply constant value elision.
|
||||||
|
module.walk(
|
||||||
|
[&](FuncOp func) { applyPatternsAndFoldGreedily(func, patterns); });
|
||||||
|
|
||||||
|
bool isLE = llvm::support::endian::system_endianness() ==
|
||||||
|
llvm::support::endianness::little;
|
||||||
|
mlir::OperationState state(module.getLoc(), "krnl.packed_const");
|
||||||
|
KrnlPackedConstantOp::build(builder, state,
|
||||||
|
builder.getIntegerType(/*width=*/64),
|
||||||
|
/*size_in_bytes=*/builder.getI64IntegerAttr(packedConst.size()),
|
||||||
|
/*is_le=*/builder.getBoolAttr(isLE),
|
||||||
|
/*value=*/nullptr,
|
||||||
|
/*file_name=*/nullptr);
|
||||||
|
auto packedConstOp =
|
||||||
|
llvm::cast<mlir::KrnlPackedConstantOp>(mlir::Operation::create(state));
|
||||||
|
module.insert(module.begin(), packedConstOp);
|
||||||
|
if (moveToFile) {
|
||||||
|
std::string pathStr;
|
||||||
|
if (filename.hasValue()) {
|
||||||
|
pathStr = filename.getValue();
|
||||||
|
} else {
|
||||||
|
llvm::SmallVector<char, 10> path;
|
||||||
|
llvm::sys::fs::createTemporaryFile("packed_const", "tmp", path);
|
||||||
|
pathStr = std::string(path.begin(), path.end());
|
||||||
|
}
|
||||||
|
packedConstOp.file_nameAttr(builder.getStringAttr(pathStr));
|
||||||
|
std::ofstream outfile(pathStr, std::ofstream::binary);
|
||||||
|
outfile.write(packedConst.data(), packedConst.size());
|
||||||
|
} else {
|
||||||
|
auto shapeTy =
|
||||||
|
RankedTensorType::get({static_cast<int64_t>(packedConst.size())},
|
||||||
|
builder.getIntegerType(8));
|
||||||
|
auto denseAttr =
|
||||||
|
DenseIntElementsAttr::get(shapeTy, llvm::makeArrayRef(packedConst));
|
||||||
|
packedConstOp.valueAttr(denseAttr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Option<bool> moveToFile{*this, "move-to-file",
|
||||||
|
llvm::cl::desc("Whether to move the packed constant to a file."),
|
||||||
|
llvm::cl::init(true)};
|
||||||
|
Option<int64_t> elisionThreshold{*this, "elision-threshold",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"A threshold value specifying the maximum number of elements a "
|
||||||
|
"constant operation can hold as an attribute. If the number exceeds "
|
||||||
|
"this threshold, constants will be packed together and, in the case "
|
||||||
|
"where `move-to-file` option is enabled, stored as a binary file on "
|
||||||
|
"disk. This can help preserve readability of IR dump and improve "
|
||||||
|
"compilation speed."),
|
||||||
|
llvm::cl::init(KrnlConstGlobalValueElision::kDefaultElisionThreshold)};
|
||||||
|
Option<std::string> filename{*this, "filename",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Specify a file in which the packed constant is to be stored.")};
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> mlir::createPackKrnlGlobalConstantsPass() {
|
||||||
|
return std::make_unique<PackKrnlGlobalConstantsPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<PackKrnlGlobalConstantsPass> pass("pack-krnl-constants",
|
||||||
|
"Elide the constant values of the Global Krnl operations.");
|
|
@ -6,6 +6,7 @@ from __future__ import unicode_literals
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
import warnings
|
||||||
import onnx.backend.base
|
import onnx.backend.base
|
||||||
import onnx.backend.test
|
import onnx.backend.test
|
||||||
|
|
||||||
|
@ -29,7 +30,57 @@ from PyRuntime import ExecutionSession
|
||||||
def execute_commands(cmds):
|
def execute_commands(cmds):
|
||||||
if (VERBOSE):
|
if (VERBOSE):
|
||||||
print(" ".join(cmds))
|
print(" ".join(cmds))
|
||||||
subprocess.run(cmds, stdout=subprocess.PIPE)
|
subprocess.run(cmds)
|
||||||
|
|
||||||
|
|
||||||
|
# There are two issues, which necessitates the adoption of this endianness
|
||||||
|
# aware wrapper around Execution Session:
|
||||||
|
# 1. Input arrays are given sometimes in native byte order, sometime in
|
||||||
|
# LE byte order, and as soon as the python array enters into py::array
|
||||||
|
# C++ objects through pybind, we will no longer be able to query their
|
||||||
|
# endianness. So we must intercept the inputs and convert them into
|
||||||
|
# native endianness.
|
||||||
|
# 2. Output arrays are compared with reference outputs, the comparison
|
||||||
|
# unfortunately includes checking that our outputs and reference outputs
|
||||||
|
# share the same endianness. So we try to figure out what is the desired
|
||||||
|
# reference output endianness, and convert our outputs to this desired
|
||||||
|
# endianness.
|
||||||
|
class EndiannessAwareExecutionSession(ExecutionSession):
|
||||||
|
def __init__(self, path, entry_point):
|
||||||
|
super().__init__(path, entry_point)
|
||||||
|
|
||||||
|
def is_input_le(self, inputs):
|
||||||
|
inputs_endianness = list(map(lambda x: x.dtype.byteorder, inputs))
|
||||||
|
endianness_is_consistent = len(set(inputs_endianness)) <= 1
|
||||||
|
assert endianness_is_consistent, \
|
||||||
|
"Input arrays contain a mixture of endianness configuration."
|
||||||
|
|
||||||
|
sys_is_le = sys.byteorder == 'little'
|
||||||
|
# To interpret character symbols indicating endianness:
|
||||||
|
# https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html
|
||||||
|
explicitly_le = inputs_endianness[0] == "<"
|
||||||
|
implicitly_le = (inputs_endianness[0] == "=" and sys_is_le)
|
||||||
|
return explicitly_le or implicitly_le
|
||||||
|
|
||||||
|
def run(self, inputs, **kwargs):
|
||||||
|
if len(inputs):
|
||||||
|
# Deduce desired endianness of output from inputs.
|
||||||
|
sys_is_le = sys.byteorder == 'little'
|
||||||
|
inp_is_le = self.is_input_le(inputs)
|
||||||
|
if (sys_is_le != inp_is_le):
|
||||||
|
inputs = list(
|
||||||
|
map(lambda x: x.byteswap().newbyteorder(), inputs))
|
||||||
|
outputs = super().run(inputs)
|
||||||
|
if (sys_is_le != inp_is_le):
|
||||||
|
outputs = list(
|
||||||
|
map(lambda x: x.byteswap().newbyteorder(), outputs))
|
||||||
|
return outputs
|
||||||
|
else:
|
||||||
|
# Can't deduce desired output endianess, fingers crossed.
|
||||||
|
warnings.warn(
|
||||||
|
"Cannot deduce desired output endianness, using native endianness by default."
|
||||||
|
)
|
||||||
|
return super().run(inputs)
|
||||||
|
|
||||||
|
|
||||||
class DummyBackend(onnx.backend.base.Backend):
|
class DummyBackend(onnx.backend.base.Backend):
|
||||||
|
@ -40,7 +91,8 @@ class DummyBackend(onnx.backend.base.Backend):
|
||||||
onnx.save(model, "temp_model.onnx")
|
onnx.save(model, "temp_model.onnx")
|
||||||
# Call frontend to process temp_model.onnx, bit code will be generated.
|
# Call frontend to process temp_model.onnx, bit code will be generated.
|
||||||
execute_commands([ONNX_MLIR, "temp_model.onnx"])
|
execute_commands([ONNX_MLIR, "temp_model.onnx"])
|
||||||
return ExecutionSession("./temp_model.so", "_dyn_entry_point_main_graph")
|
return EndiannessAwareExecutionSession("./temp_model.so",
|
||||||
|
"_dyn_entry_point_main_graph")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def supports_device(cls, device):
|
def supports_device(cls, device):
|
||||||
|
@ -80,7 +132,6 @@ test_to_enable = [
|
||||||
"test_concat_3d_axis_0_cpu",
|
"test_concat_3d_axis_0_cpu",
|
||||||
"test_concat_3d_axis_1_cpu",
|
"test_concat_3d_axis_1_cpu",
|
||||||
"test_concat_3d_axis_2_cpu",
|
"test_concat_3d_axis_2_cpu",
|
||||||
|
|
||||||
"test_concat_1d_axis_negative_1_cpu",
|
"test_concat_1d_axis_negative_1_cpu",
|
||||||
"test_concat_2d_axis_negative_1_cpu",
|
"test_concat_2d_axis_negative_1_cpu",
|
||||||
"test_concat_2d_axis_negative_2_cpu",
|
"test_concat_2d_axis_negative_2_cpu",
|
||||||
|
@ -359,12 +410,18 @@ test_to_enable = [
|
||||||
"test_split_variable_parts_1d_cpu",
|
"test_split_variable_parts_1d_cpu",
|
||||||
"test_split_variable_parts_2d_cpu",
|
"test_split_variable_parts_2d_cpu",
|
||||||
"test_split_variable_parts_default_axis_cpu",
|
"test_split_variable_parts_default_axis_cpu",
|
||||||
|
|
||||||
|
# ResNet
|
||||||
|
"test_resnet50_cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Extract name of all test cases.
|
# Extract name of all test cases.
|
||||||
import inspect
|
import inspect
|
||||||
all_tests = inspect.getmembers(
|
all_tests = []
|
||||||
|
all_tests += inspect.getmembers(
|
||||||
|
backend_test.test_cases["OnnxBackendRealModelTest"])
|
||||||
|
all_tests += inspect.getmembers(
|
||||||
backend_test.test_cases["OnnxBackendNodeModelTest"])
|
backend_test.test_cases["OnnxBackendNodeModelTest"])
|
||||||
all_test_names = list(map(lambda x: x[0], all_tests))
|
all_test_names = list(map(lambda x: x[0], all_tests))
|
||||||
|
|
||||||
|
@ -372,8 +429,7 @@ all_test_names = list(map(lambda x: x[0], all_tests))
|
||||||
for test_name in test_to_enable:
|
for test_name in test_to_enable:
|
||||||
assert test_name in all_test_names, """test name {} not found, it is likely
|
assert test_name in all_test_names, """test name {} not found, it is likely
|
||||||
that you may have misspelled the test name or the specified test does not
|
that you may have misspelled the test name or the specified test does not
|
||||||
exist in the version of onnx package you installed.""".format(
|
exist in the version of onnx package you installed.""".format(test_name)
|
||||||
test_name)
|
|
||||||
backend_test.include(r"^{}$".format(test_name))
|
backend_test.include(r"^{}$".format(test_name))
|
||||||
|
|
||||||
# import all test cases at global scope to make them visible to python.unittest
|
# import all test cases at global scope to make them visible to python.unittest
|
||||||
|
|
|
@ -6,7 +6,7 @@ configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||||
MAIN_CONFIG
|
MAIN_CONFIG
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py)
|
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py)
|
||||||
|
|
||||||
set(ONNX_MLIR_TEST_DEPENDS onnx-mlir-opt)
|
set(ONNX_MLIR_TEST_DEPENDS onnx-mlir-opt binary-decoder)
|
||||||
|
|
||||||
add_lit_testsuite(check-onnx-lit
|
add_lit_testsuite(check-onnx-lit
|
||||||
"Running the ONNX MLIR regression tests"
|
"Running the ONNX MLIR regression tests"
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test1.bin' %s -split-input-file && binary-decoder test1.bin -s 0 -n 16 --onnx::TensorProto::FLOAT -rm | FileCheck %s -check-prefix=BINARY_DECODER_1
|
||||||
|
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test2.bin' %s -split-input-file && binary-decoder test2.bin -s 16 -n 32 --onnx::TensorProto::INT32 -rm | FileCheck %s -check-prefix=BINARY_DECODER_2
|
||||||
|
|
||||||
|
// BINARY_DECODER_1: 0.1 0.2 0.3 0.4
|
||||||
|
// BINARY_DECODER_2: 1 2 3 4
|
||||||
|
func @test_krnl_const_packing_file_mixing_types() -> memref<1x4xf32> {
|
||||||
|
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0.1, 0.2, 0.3, 0.4]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
|
||||||
|
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[1, 2, 3, 4]]> : tensor<1x4xi32>} : () -> memref<1x4xi32>
|
||||||
|
return %0 : memref<1x4xf32>
|
||||||
|
}
|
|
@ -0,0 +1,8 @@
|
||||||
|
import sys
|
||||||
|
|
||||||
|
root = config.root
|
||||||
|
|
||||||
|
if sys.byteorder == "little":
|
||||||
|
config.unsupported = True
|
||||||
|
else:
|
||||||
|
config.unsupported = False
|
|
@ -0,0 +1,11 @@
|
||||||
|
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=false' %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: [[CONST_PACK:%.+]] = "krnl.packed_const"() {is_le = false, size_in_bytes = 32 : i64, value = dense<[0, 0, 0, 0, 63, -128, 0, 0, 64, 0, 0, 0, 64, 64, 0, 0, 0, 0, 0, 0, 63, -128, 0, 0, 64, 0, 0, 0, 64, 64, 0, 0]> : tensor<32xi8>} : () -> i64
|
||||||
|
// CHECK-LABEL: func @test_krnl_const_packing() -> memref<1x4xf32> {
|
||||||
|
// CHECK-NEXT: [[CONST0:%.+]] = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "krnl.global"() {name = "constant_1", offset = 16 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
|
||||||
|
func @test_krnl_const_packing() -> memref<1x4xf32> {
|
||||||
|
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
|
||||||
|
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
|
||||||
|
return %0 : memref<1x4xf32>
|
||||||
|
}
|
|
@ -0,0 +1,8 @@
|
||||||
|
import sys
|
||||||
|
|
||||||
|
root = config.root
|
||||||
|
|
||||||
|
if sys.byteorder == "little":
|
||||||
|
config.unsupported = False
|
||||||
|
else:
|
||||||
|
config.unsupported = True
|
|
@ -0,0 +1,11 @@
|
||||||
|
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=false' %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: [[CONST_PACK:%.+]] = "krnl.packed_const"() {is_le = true, size_in_bytes = 32 : i64, value = dense<[0, 0, 0, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 0, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64]> : tensor<32xi8>} : () -> i64
|
||||||
|
// CHECK-LABEL: func @test_krnl_const_packing() -> memref<1x4xf32> {
|
||||||
|
// CHECK-NEXT: [[CONST0:%.+]] = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "krnl.global"() {name = "constant_1", offset = 16 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
|
||||||
|
func @test_krnl_const_packing() -> memref<1x4xf32> {
|
||||||
|
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
|
||||||
|
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
|
||||||
|
return %0 : memref<1x4xf32>
|
||||||
|
}
|
|
@ -0,0 +1,8 @@
|
||||||
|
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test-pack-consts-to-file-same-type.bin' %s -split-input-file && binary-decoder test-pack-consts-to-file-same-type.bin -s 0 -n 32 --onnx::TensorProto::FLOAT -rm | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: 0 1 2 3 0 1 2 3
|
||||||
|
func @test_krnl_const_packing_file() -> memref<1x4xf32> {
|
||||||
|
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
|
||||||
|
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
|
||||||
|
return %0 : memref<1x4xf32>
|
||||||
|
}
|
|
@ -31,7 +31,7 @@ tool_dirs = [
|
||||||
config.onnx_mlir_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir
|
config.onnx_mlir_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir
|
||||||
]
|
]
|
||||||
tool_names = [
|
tool_names = [
|
||||||
'onnx-mlir-opt', 'mlir-opt', 'mlir-translate'
|
'onnx-mlir-opt', 'mlir-opt', 'mlir-translate', "binary-decoder"
|
||||||
]
|
]
|
||||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
import lit.llvm
|
import lit.llvm
|
||||||
|
|
||||||
if '@BUILD_SHARED_LIBS@' == 'ON':
|
if '@BUILD_SHARED_LIBS@' == 'ON':
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s
|
// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32>
|
// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32>
|
||||||
func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> {
|
func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> {
|
||||||
%0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32>
|
%0 = "krnl.global"() {name = "constant_0", shape = [1, 70], value = dense<[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]> : tensor<1x70xf32>} : () -> memref<1x70xf32>
|
||||||
return %0 : memref<1x10xf32>
|
return %0 : memref<1x70xf32>
|
||||||
|
|
||||||
// CHECK: %0 = "krnl.global"() {name = "constant_0", shape = [1, 10]} : () -> memref<1x10xf32>
|
// CHECK: {{.*}} = "krnl.global"() {name = "constant_0", shape = [1, 70]} : () -> memref<1x70xf32>
|
||||||
// CHECK: return %0 : memref<1x10xf32>
|
// CHECK: return {{.*}} : memref<1x70xf32>
|
||||||
}
|
}
|
|
@ -37,7 +37,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
||||||
llvm::SmallVector<Type, 1> outputsType{yType};
|
llvm::SmallVector<Type, 1> outputsType{yType};
|
||||||
|
|
||||||
auto funcType = builder.getFunctionType(inputsType, outputsType);
|
auto funcType = builder.getFunctionType(inputsType, outputsType);
|
||||||
string funcName = "test_conv";
|
string funcName = "main_graph";
|
||||||
llvm::SmallVector<NamedAttribute, 1> attrs;
|
llvm::SmallVector<NamedAttribute, 1> attrs;
|
||||||
auto funcOp =
|
auto funcOp =
|
||||||
builder.create<FuncOp>(UnknownLoc::get(&ctx), funcName, funcType, attrs);
|
builder.create<FuncOp>(UnknownLoc::get(&ctx), funcName, funcType, attrs);
|
||||||
|
@ -88,13 +88,13 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
||||||
OwningModuleRef moduleRef(module);
|
OwningModuleRef moduleRef(module);
|
||||||
|
|
||||||
llvm::SmallVector<char, 10> path;
|
llvm::SmallVector<char, 10> path;
|
||||||
llvm::sys::fs::createTemporaryFile("_test_conv", "", path);
|
llvm::sys::fs::createTemporaryFile("_main_graph", "", path);
|
||||||
string pathStr(path.begin(), path.end());
|
string pathStr(path.begin(), path.end());
|
||||||
llvm::FileRemover remover(path);
|
llvm::FileRemover remover(path);
|
||||||
|
|
||||||
compileModule(moduleRef, ctx, pathStr, EmitLib);
|
compileModule(moduleRef, ctx, pathStr, EmitLib);
|
||||||
onnx_mlir::ExecutionSession sess(
|
onnx_mlir::ExecutionSession sess(
|
||||||
pathStr + ".so", "_dyn_entry_point_test_conv");
|
pathStr + ".so", "_dyn_entry_point_main_graph");
|
||||||
|
|
||||||
std::vector<unique_ptr<DynMemRef>> inputs;
|
std::vector<unique_ptr<DynMemRef>> inputs;
|
||||||
auto xDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({N, C, H, W}));
|
auto xDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({N, C, H, W}));
|
||||||
|
|
Loading…
Reference in New Issue