Enable ONNX Backend Test (#1)

* wip, commit before merging with upstream
* organize API, return wrapped output
* enable onnx backend test
* undo unintentional commit
* fix krnl ops tablegen
* format krnl ops
* reorder fillDynMemRefWithMemRef to be after fillPtrToMemRefWithDynMemRef, better comments
* more onnx backend tests
* ensure that test names refer to existing tests
* improve code readability by shortening type names
* nit
* restore unintentional changes
* more nits
* fix ; -> :
* split runtime implementation into header and body file, add support for data types
* comment on the onnx backend test
* make the comments read better
* do not dump when lowering

parent 5573cb39fe
commit 685bf23b40
@@ -26,6 +26,7 @@ add_subdirectory(third_party/pybind11)
 set(CMAKE_CXX_STANDARD 14)
 add_subdirectory(src/builder)
 add_subdirectory(src/compiler)
+add_subdirectory(src/runtime)
 add_subdirectory(src)

 add_subdirectory(test)
@@ -34,7 +34,7 @@ set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include)
 set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include)
 set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin)

-set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/src/compiler/tool/onnf_opt)
+set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin)
 set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
 set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)

@@ -614,7 +614,6 @@ private:
       onnx::NodeProto node, int nIn, int nOut,
       std::initializer_list<std::tuple<std::string, std::string, std::string>>
           attrs) {
-
     // Conv has attribute dilations, kernel_shape, pads, the default value of
     // which is determined by the shape of first argument. However, since the
     // shape is unknown now, these attributes can be not generated auto
@@ -686,7 +685,7 @@ private:
   }

   void ImportGraph(const onnx::GraphProto &graph,
-                   const std::string &name = "main") {
+                   const std::string &name = "main_graph") {
     // create a function for the graph
     // TODO:
     // * get name and type for the function.
@@ -699,13 +698,18 @@ private:
    }

    // TODO: import the initializer
-   auto func_type = builder_.getFunctionType(arg_types, {});
-   auto main_func =
-       mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {});
-   auto &entryBlock = *main_func.addEntryBlock();
+   auto funcType = builder_.getFunctionType(arg_types, {});
+   auto mainFunc =
+       mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
+   auto entryPoint = mlir::ONNXEntryPointOp::create(
+       UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(),
+       /*numOutputs=*/graph.output().size());
+
+   auto &entryBlock = *mainFunc.addEntryBlock();
    builder_.setInsertionPointToStart(&entryBlock);
-   module_.push_back(main_func);
+   module_.push_back(mainFunc);
+   module_.push_back(entryPoint);

    for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) {
      ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
@@ -728,8 +732,8 @@ private:
    builder_.create<mlir::ReturnOp>(UnknownLoc(), ret_vals);
    // Update main function signature to reflect types of newly imported
    // output tensors.
-   func_type = builder_.getFunctionType(arg_types, ret_types);
-   main_func.setType(func_type);
+   funcType = builder_.getFunctionType(arg_types, ret_types);
+   mainFunc.setType(funcType);
  }
 }; // FrontendGenImpl class
 } // namespace
@@ -21,7 +21,7 @@
 namespace mlir {
 class MLIRContext;
 class OwningModuleRef;
 } // namespace mlir

 //===----------------------------------------------------------------------===//
 // Import a model into one of ONNF's frontend models.
@@ -41,7 +41,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
  * @return MLIR::module generated for the ONNX model.
  */
 void ImportFrontendModelFile(std::string model_fname,
-    mlir::MLIRContext& context, mlir::OwningModuleRef& module);
+                             mlir::MLIRContext &context,
+                             mlir::OwningModuleRef &module);

 /*!
  * TODO: Import models into other extension dialects that cover the
@@ -0,0 +1,5 @@
+add_library(DLCAnalysis STATIC
+            extract_integer_set.cpp)
+
+target_include_directories(DLCAnalysis PRIVATE ${DLC_SRC_ROOT})
+target_include_directories(DLCAnalysis PRIVATE ${DLC_BIN_ROOT})
@@ -364,6 +364,14 @@ ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
   return success();
 }

+void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state,
+                             SymbolRefAttr funcAttr, IntegerAttr numInputs,
+                             IntegerAttr numOutputs) {
+  state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr);
+  state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs);
+  state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs);
+}
+
 #define GET_OP_CLASSES
 #include "src/compiler/krnl.cpp.inc"
 } // namespace mlir
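For orientation, a minimal sketch of how this builder is typically reached through an OpBuilder. The call site below is hypothetical (in this commit the op is actually created by the ONNXEntryPointLowering pattern further down); `builder` and `loc` are assumed to exist at the call site.

// Sketch only: build a krnl.entry_point referring to a function named
// "main_graph", with one input and one output.
builder.create<KrnlEntryPointOp>(
    loc, builder.getSymbolRefAttr("main_graph"),
    /*numInputs=*/builder.getI32IntegerAttr(1),
    /*numOutputs=*/builder.getI32IntegerAttr(1));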
@@ -11,6 +11,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"

@@ -8,35 +8,27 @@

 include "mlir/IR/OpBase.td"


 def Krnl_Dialect : Dialect {
   let name = "krnl";
   let cppNamespace = "";
 }

 // Require regions to have krnl.terminate terminator operation.
-def ImplicitKrnlTerminator
-    : SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
+def ImplicitKrnlTerminator : SingleBlockImplicitTerminator<"KrnlTerminatorOp">;

 def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
   let summary = "define_loops operation";
   let description = [{

     The "krnl.define_loops" operation is used to define input loops,
     those are the for loops appearing in the input program that we
     intend to optimize.

   }];

   let arguments = (ins);
   let results = (outs Variadic<AnyType>);

   let skipDefaultBuilders = 1;
-  let builders = [
-      OpBuilder<"Builder *builder, OperationState &result,"
-                "int64_t num_loops">
-  ];
+  let builders = [ OpBuilder<"Builder *builder, OperationState &result,"
+                             "int64_t num_loops"> ];

   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
@@ -44,33 +36,27 @@ def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
   let extraClassDeclaration = [{
     static StringRef getNumLoopsAttrName() { return "num_loops"; }

    // Helper function to extract the number of loops being defined.
    int64_t getNumLoops() {
-     auto num_loops =
-         getAttrOfType<IntegerAttr>(
-             getNumLoopsAttrName())
-             .getValue()
-             .getSExtValue();
-     return num_loops;
-   }
-  }];
+     auto num_loops = getAttrOfType<IntegerAttr>(getNumLoopsAttrName())
+                          .getValue()
+                          .getSExtValue();
+     return num_loops;
+   }
+  }];


 }

 def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
   let summary = "optimize_loops operation";
   let description = [{

     The "krnl.optimize_loops" operation is essentially a cosmetic operation
-    which exists to encapsulate a region where loops are being scheduled/optimized.
+    which exists to encapsulate a region where loops are being scheduled /
+    optimized.

-    The optimized loops are returned at the end of the
-    region associated with the krnl.optimize_loops operation.
+    The optimized loops are returned at the end of the region associated with
+    the krnl.optimize_loops operation.

-    For example:
-    TBD once we have actual schedule intrinsics.
-
+    For example : TBD once we have actual schedule intrinsics.
   }];

   let arguments = (ins Variadic<AnyType>);
@@ -79,10 +65,8 @@ def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {

   let skipDefaultBuilders = 1;

-  let builders = [
-      OpBuilder<"Builder *builder, OperationState &result, "
-                "int timestamp_space_rank">
-  ];
+  let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
+                             "int timestamp_space_rank"> ];

   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
@@ -91,7 +75,6 @@ def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
 def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
   let summary = "iterate operation";
   let description = [{

     The "krnl.iterate" operation is conceptually equivalent to a nested for loops.

     For instance, say we have the following two
@@ -103,25 +86,20 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {

     Then, consider the following krnl.iterate operation:
     krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) {
       // Some operations.
     }

     It is equivalent to:
-    for (i0=0; i0<10; i0++)
-      for (i1=0; i1<10; i1++)
+    for (i0 = 0; i0 < 10; i0++)
+      for (i1 = 0; i1 < 10; i1++)
         // Some operations.
   }];

   let arguments = (ins Variadic<AnyType>);

   let regions = (region SizedRegion<1>:$bodyRegion);

   let skipDefaultBuilders = 1;
-  let builders = [
-      OpBuilder<"Builder *builder, OperationState &result, "
-                "KrnlIterateOperandPack operandPack">
-  ];
+  let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
+                             "KrnlIterateOperandPack operandPack"> ];

   let extraClassDeclaration = [{
     // In krnl.iterate operation, operands are stored as such
@@ -134,20 +112,19 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {

    int64_t getNumOptimizedLoops() {
      auto num_optimized_loops =
-         getAttrOfType<IntegerAttr>(
-             getNumOptimizedLoopsAttrName())
-             .getValue()
-             .getSExtValue();
+         getAttrOfType<IntegerAttr>(getNumOptimizedLoopsAttrName())
+             .getValue()
+             .getSExtValue();
      return num_optimized_loops;
    }

    // Get name of the attribute for storing bound represented using affine maps.
    static StringRef getBoundsAttrName() { return "bounds"; }
  }];

  let printer = [{ return ::print(p, *this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let verifier = [{ return ::verify(*this); }];
 }

 def KrnlReturnLoopsOp : Op<Krnl_Dialect, "return_loops", [Terminator]> {
@@ -182,11 +159,30 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
   let verifier = ?;
 }

+def KrnlEntryPointOp : Op<Krnl_Dialect, "entry_point"> {
+  let summary = "Indicate ONNX entry point";
+  let description = [{The "krnl.entry_point" function indicates the main entry
+                      point of ONNX model.}];
+  let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
+                             "SymbolRefAttr funcAttr, IntegerAttr numInputs, "
+                             "IntegerAttr numOutputs"> ];
+
+  let extraClassDeclaration = [{
+    static StringRef getEntryPointFuncAttrName() { return "func"; }
+    static StringRef getNumInputsAttrName() { return "numInputs"; }
+    static StringRef getNumOutputsAttrName() { return "numOutputs"; }
+  }];
+
+  // No custom parsing/printing form.
+  let parser = ?;
+  let printer = ?;
+}
+
 def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
   let summary = "Krnl memcpy operation";
   let description = [{
-    In the KRNL dialect the reshape op doesn't generate a new memory entry and
-    treats a reshape like a cast.
+    In the KRNL dialect the reshape op
+    doesn't generate a new memory entry and treats a reshape like a cast.
   }];

   let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);
@@ -58,6 +58,27 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :

 include "dialect/onnx/onnxop.inc"

+// Indicate entry point functions of ONNX graph.
+def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
+  let summary = "Indicate ONNX entry point";
+  let description = [{
+    The "onnx.EntryPoint" function indicates the main entry point of ONNX model.
+  }];
+
+  let builders = [OpBuilder<[{Builder *builder, OperationState &state,
+                              FuncOp function, int numInputs, int numOutputs}]>];
+
+  let extraClassDeclaration = [{
+    static ONNXEntryPointOp create(Location location, FuncOp& func,
+                                   int numInputs, int numOutputs);
+
+    static StringRef getEntryPointFuncAttrName() { return "func"; }
+    static StringRef getNumInputsAttrName() { return "numInputs"; }
+    static StringRef getNumOutputsAttrName() { return "numOutputs"; }
+  }];
+}
+

 def ONNXFullGemmOp: ONNX_Op<"FullGemm",
   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX general matrix multiply operation";
@@ -7,7 +7,6 @@
 // This file defines ONNX operations in the MLIR operation set.
 //
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
@@ -38,6 +37,28 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
       >();
 }

+void ONNXEntryPointOp::build(mlir::Builder *builder,
+                             mlir::OperationState &state, mlir::FuncOp function,
+                             int numInputs, int numOutputs) {
+  state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
+                     builder->getSymbolRefAttr(function));
+  state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
+                     builder->getI32IntegerAttr(numInputs));
+  state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
+                     builder->getI32IntegerAttr(numOutputs));
+}
+
+ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
+                                          mlir::FuncOp &func, int numInputs,
+                                          int numOutputs) {
+  mlir::OperationState state(location, "onnx.EntryPoint");
+  Builder builder(location->getContext());
+  mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
+  Operation *op = mlir::Operation::create(state);
+  auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
+  return onnxEntryOp;
+}
+
 //===----------------------------------------------------------------------===//
 // ONNX Operations
 //===----------------------------------------------------------------------===//
@@ -10,6 +10,7 @@

 #pragma once

+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
@@ -8,7 +8,6 @@
 // Krnl IR and standard operations.
 //
 //===----------------------------------------------------------------------===//

 #include <map>

 #include "mlir/Dialect/AffineOps/AffineOps.h"
@@ -884,6 +883,27 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   }
 };

+//===----------------------------------------------------------------------===//
+// EntryPoint Op lowering to Krnl Entry Point.
+//===----------------------------------------------------------------------===//
+
+class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
+public:
+  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
+                                     PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
+        op,
+        op.getAttrOfType<SymbolRefAttr>(
+            ONNXEntryPointOp::getEntryPointFuncAttrName()),
+        op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
+        op.getAttrOfType<IntegerAttr>(
+            ONNXEntryPointOp::getNumOutputsAttrName()));
+    return matchSuccess();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Conversion from Tensor type to the Standard dialect MemRef type.
 //===----------------------------------------------------------------------===//
@@ -985,7 +1005,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
-                  ONNXReshapeOpLowering>(&getContext());
+                  ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext());

   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
@@ -143,14 +143,16 @@ void KrnlToAffineLoweringPass::runOnFunction() {
   // We expect IR to be free of Krnl Dialect Ops.
   target.addIllegalDialect<KrnlOpsDialect>();
   target.addLegalOp<KrnlMemcpyOp>();
+  target.addLegalOp<KrnlEntryPointOp>();

   OwningRewritePatternList patterns;
   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
                   KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
       &getContext());

-  if (failed(applyPartialConversion(getFunction(), target, patterns)))
+  if (failed(applyPartialConversion(getFunction(), target, patterns))) {
     signalPassFailure();
+  }
 }

 } // namespace
@@ -4,7 +4,6 @@
 //
 //===----------------------------------------------------------------------===//

-#include "llvm/ADT/Sequence.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
@@ -15,6 +14,7 @@
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/Sequence.h"

 #include "src/compiler/dialect/krnl/krnl_ops.hpp"
 #include "src/compiler/pass/passes.hpp"
@@ -23,20 +23,39 @@ using namespace mlir;

 namespace {

+static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
+                                               ModuleOp module,
+                                               mlir::LLVM::LLVMType funcType,
+                                               PatternRewriter &rewriter) {
+  auto *context = module.getContext();
+  if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
+    auto symbolRef = SymbolRefAttr::get(funcName, context);
+    assert(symbolRef.getType() == funcType && "wrong symbol type");
+    return symbolRef;
+  }
+
+  // Insert the function into the body of the parent module.
+  PatternRewriter::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPointToStart(module.getBody());
+  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
+  return SymbolRefAttr::get(funcName, context);
+}
+
 //===----------------------------------------------------------------------===//
-// KRNL to LLVM: patterns which need a direct lowering to LLVM.
+// KRNL to LLVM: KrnlMemcpyOpLowering
 //===----------------------------------------------------------------------===//

 class KrnlMemcpyOpLowering : public ConversionPattern {
 public:
-  explicit KrnlMemcpyOpLowering(MLIRContext* context)
+  explicit KrnlMemcpyOpLowering(MLIRContext *context)
       : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}

-  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
-      ConversionPatternRewriter& rewriter) const override {
-    auto* context = op->getContext();
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *context = op->getContext();
     auto loc = op->getLoc();
-    auto* llvmDialect =
+    auto *llvmDialect =
         op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
     assert(llvmDialect && "expected llvm dialect to be registered");

@@ -47,39 +66,40 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
     // First operand.
     Type dstType =
         operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
-    Value* alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
+    Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
         loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
-    Value* alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
+    Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);

     // Second operand.
     Type srcType =
         operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
-    Value* alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
+    Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
        loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
-    Value* alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
+    Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);

     // Size.
-    Value* int64Size = rewriter.create<LLVM::SExtOp>(
+    Value *int64Size = rewriter.create<LLVM::SExtOp>(
        loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);

     // Memcpy call
-    rewriter.create<CallOp>(loc, memcpyRef,
-        LLVM::LLVMType::getVoidTy(llvmDialect),
-        ArrayRef<Value*>(
+    rewriter.create<CallOp>(
+        loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
+        ArrayRef<Value *>(
            {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));

     rewriter.eraseOp(op);
     return matchSuccess();
   }

 private:
   /// Return a symbol reference to the memcpy function, inserting it into the
   /// module if necessary.
-  static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter,
-      ModuleOp module, LLVM::LLVMDialect* llvmDialect) {
-    auto* context = module.getContext();
+  static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
+                                             ModuleOp module,
+                                             LLVM::LLVMDialect *llvmDialect) {
+    auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
       return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
     // Create a function declaration for memcpy, the signature is:
@@ -87,19 +107,355 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
     auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
     auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
-    auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
+    auto llvmFnType = LLVM::LLVMType::getFunctionTy(
+        llvmVoidTy,
         ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
         false);

     // Insert the memcpy function into the body of the parent module.
     PatternRewriter::InsertionGuard insertGuard(rewriter);
     rewriter.setInsertionPointToStart(module.getBody());
-    rewriter.create<LLVM::LLVMFuncOp>(
-        module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(),
+                                      "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
     return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
   }
 };
-} // end namespace
+
+//===----------------------------------------------------------------------===//
+// KRNL to LLVM: KrnlEntryPointOp
+//===----------------------------------------------------------------------===//
+
+class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
+public:
+  using OpRewritePattern<KrnlEntryPointOp>::OpRewritePattern;
+
+  enum class API {
+    CREATE_ORDERED_DYN_MEM_REF_DICT,
+    CREATE_DYN_MEM_REF,
+    GET_DYN_MEM_REF,
+    SET_DYN_MEM_REF,
+    GET_DATA,
+    SET_DATA,
+    GET_SIZES,
+    GET_STRIDES,
+  };
+
+  struct ApiSpec {
+    API id;
+    std::string name;
+    FlatSymbolRefAttr symbolRef;
+    LLVM::LLVMType outputTy;
+    SmallVector<LLVM::LLVMType, 4> inputTys;
+
+    ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy,
+            ArrayRef<LLVM::LLVMType> inputTys)
+        : id(id), name(name), outputTy(outputTy),
+          inputTys(inputTys.begin(), inputTys.end()) {}
+
+    LLVM::LLVMType funcTy() {
+      return LLVM::LLVMType::getFunctionTy(outputTy, inputTys,
+                                           /*isVarArg=*/false);
+    }
+  };
+
+  PatternMatchResult matchAndRewrite(KrnlEntryPointOp op,
+                                     PatternRewriter &rewriter) const override {
+
+    auto *llvmDialect =
+        op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    assert(llvmDialect && "expected llvm dialect to be registered");
+    auto module = op.getParentOfType<ModuleOp>();
+    auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect);
+    auto loc = op.getLoc();
+    auto numOutputs =
+        op.getAttrOfType<IntegerAttr>(KrnlEntryPointOp::getNumOutputsAttrName())
+            .getInt();
+
+    using LLVMType = LLVM::LLVMType;
+    auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
+    auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
+
+    // Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic
+    // signature. The signature is dynamic because it remains the same no matter
+    // what the model input/output schema look like. Such dynamic signature
+    // takes a opaque ptr as input, representing a ptr to a data structure
+    // containing a set of dynamic memrefs wrapped in a vector; similarly the
+    // output is also a opaque ptr to a data structure with output memrefs
+    // wrapped within it.
+    auto staticEntryPointFuncName =
+        op.getAttrOfType<SymbolRefAttr>(
+              KrnlEntryPointOp::getEntryPointFuncAttrName())
+            .getLeafReference();
+    auto dynEntryPointName = "_dyn_entry_point_" + staticEntryPointFuncName;
+    assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr &&
+           "dynamic entry point name is not unique");
+    rewriter.eraseOp(op);
+    auto dynEntryPointFuncTy =
+        LLVMType::getFunctionTy(opaquePtrTy, {opaquePtrTy}, false);
+    auto dynamicEntryPointFunc = rewriter.create<LLVM::LLVMFuncOp>(
+        loc, dynEntryPointName.str(), dynEntryPointFuncTy);
+    auto &entryPointEntryBlock =
+        createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc);
+    rewriter.setInsertionPointToStart(&entryPointEntryBlock);
+
+    // Based on the static entry point type signature, unpack dynamic memory
+    // refs to corresponding static memory refs.
+    auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
+    assert(staticEntryPointFunc &&
+           isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
+           "entry point func must exist and be an llvm func op");
+    auto staticEntryPointTy = dyn_cast<LLVM::LLVMFuncOp>(staticEntryPointFunc)
+                                  .getType()
+                                  .dyn_cast<LLVMType>();
+
+    // Retrieve dynamic mem refs from wrapped input, and convert every one of
+    // them to static mem refs.
+    SmallVector<Value *, 4> staticInputs;
+    auto wrappedInput = entryPointEntryBlock.getArgument(0);
+    for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
+      // Call API function to retrieve the i-th dynamic memref.
+      auto idxVal = rewriter.create<LLVM::ConstantOp>(
+          loc, int32Ty, rewriter.getI32IntegerAttr(i));
+      auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
+                               {wrappedInput, idxVal});
+
+      // Create a (static) memref type corresponding to the i-th memref input to
+      // the inference function on stack, and load it to memRef.
+      auto memRefPtrTy = staticEntryPointTy.getFunctionParamType(i);
+      auto memRefTy = memRefPtrTy.getPointerElementTy();
+      auto one = rewriter.create<LLVM::ConstantOp>(
+          loc, int32Ty, rewriter.getI32IntegerAttr(1));
+      Value *ptrToMemRef =
+          rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
+                                          /*alignment=*/0);
+
+      // Fill in the memref underlying ptrToMemRef with information extracted
+      // from dynMemRef.
+      fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc,
+                                   apiRegistry, llvmDialect);
+
+      // ptrToMemRef will be an input to main computation graph function.
+      staticInputs.emplace_back(ptrToMemRef);
+    }
+
+    // If more than one output exists, the struct becomes a nested struct,
+    // the unpacking logic can be more involved, so no support for now.
+    assert(numOutputs == 1 && "only support 1 output tensor now.");
+
+    // Call static entry point with the memref ptrs created, and get output.
+    auto outputMemRefs = rewriter.create<LLVM::CallOp>(
+        loc, staticEntryPointTy.getFunctionResultType(),
+        rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
+
+    // Create wrapped output.
+    auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
+                                 API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
+
+    // Get the first memref returned, convert to a dynamic memref and store
+    // it in the wrapped Output.
+    auto outMemRef = outputMemRefs.getResult(0);
+    auto outMemRefTy = outMemRef->getType().dyn_cast<LLVMType>();
+    auto outMemRefRank =
+        outMemRefTy.getStructElementType(3).getArrayNumElements();
+    auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
+    auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
+                                API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
+    fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc,
+                            apiRegistry, llvmDialect);
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Ty, rewriter.getI32IntegerAttr(0));
+    callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
+            {wrappedOutput, zero, outDynMemRef});
+
+    // Return wrapped output.
+    rewriter.create<LLVM::ReturnOp>(loc,
+                                    SmallVector<Value *, 1>({wrappedOutput}));
+    return matchSuccess();
+  }
+
+private:
+  using ApiRegistry = std::map<API, ApiSpec>;
+
+  ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter,
+                              LLVM::LLVMDialect *llvmDialect) const {
+    using LLVMType = LLVM::LLVMType;
+    auto voidTy = LLVMType::getVoidTy(llvmDialect);
+    auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
+    auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
+    auto int64Ty = LLVMType::getInt64Ty(llvmDialect);
+    auto int64PtrTy = int64Ty.getPointerTo();
+
+    // Declare API type as an enum value, its string name and an LLVM Type
+    // specifying its signature.
+    // clang-format off
+    std::vector<ApiSpec> apiSpecs = {
+        ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}),
+        ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}),
+        ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
+        ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
+        ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
+        ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
+        ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
+        ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy})
+    };
+    // clang-format on
+
+    // Declare APIs in the current module and build an API registry mapping api
+    // identities to a symbol reference to the API function.
+    ApiRegistry registry;
+    for (auto &apiSpec : apiSpecs) {
+      apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module,
+                                                apiSpec.funcTy(), rewriter);
+      registry.emplace(apiSpec.id, apiSpec);
+    }
+
+    return registry;
+  }
+
+  // Call a registered API, return the return SSA values if only one result is
+  // returned, otherwise return nullptr.
+  Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
+                 API apiId, ArrayRef<Value *> params) const {
+    auto returnVals = rewriter.create<LLVM::CallOp>(
+        loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
+        ArrayRef<Value *>(params));
+    if (returnVals.getNumResults() == 1)
+      return returnVals.getResult(0);
+    return nullptr;
+  }
+
+  // Helper function to insert an entry block to LLVM function.
+  // (TODO): upstream this to MLIR.
+  Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType,
+                          LLVM::LLVMFuncOp &dynamicEntryPointFunc) const {
+    // Add entry block:
+    auto *entryPointEntryBlock = new Block();
+    dynamicEntryPointFunc.push_back(entryPointEntryBlock);
+    llvm::SmallVector<Type, 4> argTypes;
+    for (size_t i = 0; i < dynEntryPointFuncType.getFunctionNumParams(); i++)
+      argTypes.emplace_back(dynEntryPointFuncType.getFunctionParamType(i));
+    entryPointEntryBlock->addArguments(argTypes);
+    return *entryPointEntryBlock;
+  }
+
+  void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
+                                    PatternRewriter &rewriter,
+                                    const Location &loc,
+                                    const std::map<API, ApiSpec> &apiRegistry,
+                                    LLVM::LLVMDialect *llvmDialect) const {
+    auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
+    auto memRefTy = memRefPtrTy.getPointerElementTy();
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
+
+    Value *memRef =
+        rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef);
+
+    // Set dataPtr and alignedDataPtr;
+    auto dataPtr =
+        callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef});
+    dataPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, memRefTy.getStructElementType(0), dataPtr);
+    memRef = rewriter.create<LLVM::InsertValueOp>(
+        loc, memRefTy, memRef, dataPtr,
+        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
+    memRef = rewriter.create<LLVM::InsertValueOp>(
+        loc, memRefTy, memRef, dataPtr,
+        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
+
+    // Use zero offset now.
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        loc, int64Ty, rewriter.getI64IntegerAttr(0));
+    memRef = rewriter.create<LLVM::InsertValueOp>(
+        loc, memRefTy, memRef, zero,
+        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
+
+    // Get rank, sizes array ptr and strides array ptr.
+    auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
+    auto sizesArrayPtr =
+        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef});
+    auto stridesArrayPtr =
+        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef});
+
+    for (decltype(rank) i = 0; i < rank; i++) {
+      auto dimIdx = rewriter.create<LLVM::ConstantOp>(
+          loc, int64Ty, rewriter.getI64IntegerAttr(i));
+
+      // Insert size of the dimension.
+      auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
+          loc, int64Ty.getPointerTo(), sizesArrayPtr,
+          ArrayRef<Value *>({dimIdx}));
+      auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
+                                                   dimSizePtr);
+      memRef = rewriter.create<LLVM::InsertValueOp>(
+          loc, memRefTy, memRef, dimSize,
+          rewriter.getArrayAttr(
+              {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
+
+      // Insert stride of the dimension.
+      auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
+          loc, int64Ty.getPointerTo(), sizesArrayPtr,
+          ArrayRef<Value *>({dimIdx}));
+      auto dimStride = rewriter.create<LLVM::LoadOp>(
+          loc, int64Ty.getPointerTo(), dimStridePtr);
+      memRef = rewriter.create<LLVM::InsertValueOp>(
+          loc, memRefTy, memRef, dimStride,
+          rewriter.getArrayAttr(
+              {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
+    }
+
+    rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef);
+  }
+
+  void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
+                               PatternRewriter &rewriter, const Location &loc,
+                               const std::map<API, ApiSpec> &apiRegistry,
+                               LLVM::LLVMDialect *llvmDialect) const {
+    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
+
+    // Extract the data pointer, and record it in dynamic mem ref created.
+    Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
+        loc, outMemRefTy.getStructElementType(0), &outMemRef,
+        rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
+    outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
+    callApi(rewriter, loc, apiRegistry, API::SET_DATA,
+            {&outDynMemRef, outMemRefDataPtr});
+
+    auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
+    auto sizesArrayPtr =
+        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef});
+    auto stridesArrayPtr =
+        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef});
+
+    for (decltype(rank) i = 0; i < rank; i++) {
+      auto dimIdx = rewriter.create<LLVM::ConstantOp>(
+          loc, int64Ty, rewriter.getI64IntegerAttr(i));
+
+      // Transfer size of dimension from memref to dynamic memref.
+      auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
+          loc, int64Ty, &outMemRef,
+          rewriter.getArrayAttr(
+              {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
+      auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
+          loc, int64Ty.getPointerTo(), sizesArrayPtr,
+          ArrayRef<Value *>({dimIdx}));
+      rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
+
+      // Transfer stride of dimension from memref to dynamic memref.
+      auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
+          loc, int64Ty, &outMemRef,
+          rewriter.getArrayAttr(
+              {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
+      auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
+          loc, int64Ty.getPointerTo(), stridesArrayPtr,
+          ArrayRef<Value *>({dimIdx}));
+      rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
+    }
+  }
+};
+} // end namespace

 //===----------------------------------------------------------------------===//
 // KRNL + Stadard + Affine dialects lowering to LLVM.
@@ -109,7 +465,7 @@ namespace {
 struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
   void runOnModule() final;
 };
 } // end anonymous namespace

 void KrnlToLLVMLoweringPass::runOnModule() {
   // Define the target for this lowering i.e. the LLVM dialect.
@@ -128,12 +484,13 @@ void KrnlToLLVMLoweringPass::runOnModule() {
   populateStdToLLVMConversionPatterns(typeConverter, patterns);

   // Lower from the `krnl` dialect i.e. the Reshape operation.
-  patterns.insert<KrnlMemcpyOpLowering>(&getContext());
+  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
+      &getContext());

   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
-  auto module = getModule();
-  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+  if (failed(
+          applyFullConversion(getModule(), target, patterns, &typeConverter)))
     signalPassFailure();
 }
@@ -142,5 +499,5 @@ std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
   return std::make_unique<KrnlToLLVMLoweringPass>();
 }

-static PassRegistration<KrnlToLLVMLoweringPass> pass(
-    "lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
+static PassRegistration<KrnlToLLVMLoweringPass>
+    pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
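To make the dynamic-signature idea concrete, a minimal host-side sketch of calling the generated `_dyn_entry_point_main_graph` through the C runtime API added later in this commit (createOrderedDynMemRefDict, createDynMemRef, setData, setSizes, getStrides, setDynMemRef, getDynMemRef, getData). The include path, the model name "main_graph", the 2x3 float shape, and the exact declaration of the entry point are assumptions for illustration only.

// Sketch only: drive a compiled model through its dynamic entry point.
#include <vector>
#include "src/runtime/dyn_memref.h"  // path is an assumption

// Generated by the KrnlEntryPointOp lowering above; both sides are opaque
// pointers to ordered dynamic memref dictionaries (declaration assumed).
extern "C" OrderedDynMemRefDict *
_dyn_entry_point_main_graph(OrderedDynMemRefDict *inputs);

int main() {
  std::vector<float> inputData(2 * 3, 1.0f);
  INDEX_TYPE shape[2] = {2, 3};

  // Wrap the input buffer in a dynamic memref and register it at index 0.
  OrderedDynMemRefDict *inputs = createOrderedDynMemRefDict();
  DynMemRef *input = createDynMemRef(/*rank=*/2);
  setData(input, inputData.data());
  setSizes(input, shape);
  int64_t *strides = getStrides(input);  // row-major strides for a 2x3 tensor
  strides[0] = 3;
  strides[1] = 1;
  setDynMemRef(inputs, 0, input);

  // Run inference through the dynamic signature and read back the output.
  OrderedDynMemRefDict *outputs = _dyn_entry_point_main_graph(inputs);
  DynMemRef *output = getDynMemRef(outputs, 0);
  float *outputData = static_cast<float *>(getData(output));
  (void)outputData;
  return 0;
}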
@@ -7,10 +7,7 @@
 //===----------------------------------------------------------------------===//

 #include <cmath>
-#include <cstdlib>
 #include <iostream>
-#include <random>
-#include <tuple>

 #include "llvm/Bitcode/BitcodeWriter.h"
 #include "llvm/Support/CommandLine.h"
@@ -24,9 +21,7 @@
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
 #include "src/compiler/pass/passes.hpp"

-#include "mlir/Analysis/Verifier.h"
 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
 #include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/IR/MLIRContext.h"
@@ -0,0 +1,17 @@
+add_library(cruntime
+            dyn_memref.cpp
+            dyn_memref.h
+            data_type.h)
+target_include_directories(cruntime
+                           PRIVATE ${DLC_SRC_ROOT} ${DLC_BIN_ROOT}
+                                   ${DLC_SRC_ROOT})
+
+pybind11_add_module(pyruntime
+                    dyn_memref.cpp
+                    dyn_memref.h
+                    runtime.cpp
+                    runtime.hpp)
+target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
+target_include_directories(pyruntime
+                           PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
+                                   ${ONNF_SRC_ROOT})
@@ -0,0 +1,30 @@
enum DYN_MEMREF_DATA_TYPE {
  UNDEFINED = 0,
  // Basic types.
  FLOAT = 1,   // float
  UINT8 = 2,   // uint8_t
  INT8 = 3,    // int8_t
  UINT16 = 4,  // uint16_t
  INT16 = 5,   // int16_t
  INT32 = 6,   // int32_t
  INT64 = 7,   // int64_t
  STRING = 8,  // string
  BOOL = 9,    // bool

  // IEEE754 half-precision floating-point format (16 bits wide).
  // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
  FLOAT16 = 10,

  DOUBLE = 11,
  UINT32 = 12,
  UINT64 = 13,
  COMPLEX64 = 14,  // complex with float32 real and imaginary components
  COMPLEX128 = 15, // complex with float64 real and imaginary components

  // Non-IEEE floating-point format based on IEEE754 single-precision
  // floating-point number truncated to 16 bits.
  // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
  BFLOAT16 = 16,

  // Future extensions go here.
};

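Since these tags mirror the ONNX data types, a runtime that handles more than float32 needs to know each tag's element width. The following is a hypothetical helper (editor-added, not part of this commit; the include path is assumed) that maps a tag to its element size in bytes:

// Editor sketch: element size in bytes for each DYN_MEMREF_DATA_TYPE tag,
// e.g. for sizing malloc'ed buffers. Returns 0 for types with no fixed size.
#include <cstddef>
#include "src/runtime/data_type.h" // assumed include path

inline size_t elementSize(DYN_MEMREF_DATA_TYPE dtype) {
  switch (dtype) {
  case UINT8:
  case INT8:
  case BOOL:
    return 1;
  case UINT16:
  case INT16:
  case FLOAT16:
  case BFLOAT16:
    return 2;
  case FLOAT:
  case INT32:
  case UINT32:
    return 4;
  case DOUBLE:
  case INT64:
  case UINT64:
  case COMPLEX64:
    return 8;
  case COMPLEX128:
    return 16;
  default:
    return 0; // UNDEFINED and STRING have no fixed element size.
  }
}
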
@@ -0,0 +1,74 @@
#include <cassert>
#include <cstdlib> // malloc
#include <map>
#include <string>
#include <vector>

#include "dyn_memref.h"

DynMemRef::DynMemRef(int _rank) {
  rank = _rank;
  sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
  strides = (int64_t *)malloc(rank * sizeof(int64_t));
}

// An ordered dynamic MemRef dictionary.
// The goal is to support accessing dynamic memrefs by name and by index.
// Currently, only accessing by index is supported.
struct OrderedDynMemRefDict {
  std::map<std::string, DynMemRef *> tensorDict;
  std::vector<std::string> orderedNames;
};

int numDynMemRefs(OrderedDynMemRefDict *dict) {
  return dict->orderedNames.size();
}

OrderedDynMemRefDict *createOrderedDynMemRefDict() {
  return new OrderedDynMemRefDict();
}

DynMemRef *createDynMemRef(int rank) { return new DynMemRef(rank); }

DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
  return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
}

void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
                  DynMemRef *tensor) {
  // Grow the name list if this index has not been used yet.
  if (tensorDict->orderedNames.size() <= (size_t)idx)
    tensorDict->orderedNames.resize(idx + 1);

  // The dynamic memref is essentially anonymous, since we are storing it by
  // indexed position.
  // TODO: can use random string as names to reduce chance of collision.
  auto unique_name = std::to_string(idx);
  assert(tensorDict->tensorDict.count(unique_name) == 0 &&
         "duplicate dynamic mem ref name");

  tensorDict->orderedNames[idx] = unique_name;
  tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
}

void *getData(DynMemRef *dynMemRef) { return dynMemRef->data; }

void setData(DynMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }

void *getAlignedData(DynMemRef *dynMemRef) { return dynMemRef->alignedData; }

void setAlignedData(DynMemRef *dynMemRef, void *dataPtr) {
  dynMemRef->alignedData = dataPtr;
}

INDEX_TYPE *getSizes(DynMemRef *dynMemRef) { return dynMemRef->sizes; }

void setSizes(DynMemRef *dynMemRef, INDEX_TYPE *sizes) {
  for (int i = 0; i < dynMemRef->rank; i++)
    dynMemRef->sizes[i] = sizes[i];
}

int64_t *getStrides(DynMemRef *dynMemRef) { return dynMemRef->strides; }

void setStrides(DynMemRef *dynMemRef, int64_t *strides) {
  // Copy the caller's strides into the memref's strides array.
  for (int i = 0; i < dynMemRef->rank; i++)
    dynMemRef->strides[i] = strides[i];
}

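To make the calling convention concrete, here is an editor-added sketch (not part of the commit; the include path is assumed) that wraps a small 2x3 float tensor and registers it at index 0, using only the API declared in dyn_memref.h plus the public DynMemRef fields:

// Editor sketch: exercising the C runtime API from plain C++.
#include <cstdlib>
#include "src/runtime/dyn_memref.h" // assumed include path

int main() {
  OrderedDynMemRefDict *dict = createOrderedDynMemRefDict();

  // A rank-2 memref describing a 2x3 float tensor.
  DynMemRef *t = createDynMemRef(/*rank=*/2);
  t->sizes[0] = 2;
  t->sizes[1] = 3;
  t->strides[0] = 3; // row-major strides, in elements
  t->strides[1] = 1;

  float *buffer = (float *)malloc(2 * 3 * sizeof(float));
  setData(t, buffer);
  setAlignedData(t, buffer);

  // Register the memref so it can be retrieved by index later.
  setDynMemRef(dict, 0, t);
  return numDynMemRefs(dict) == 1 ? 0 : 1;
}
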
@@ -0,0 +1,61 @@
#pragma once

#include <cstdint>

typedef int64_t INDEX_TYPE;

// This is a dynamic version of memref.
// The same struct can be used to represent memrefs of
// all ranks and type combinations.
struct DynMemRef {
  void *data;
  void *alignedData;
  INDEX_TYPE offset;

  unsigned int rank;
  INDEX_TYPE *sizes;
  int64_t *strides;

  DynMemRef(int _rank);
};

extern "C" {

// Ordered DynMemRef Dictionary is a data structure for wrapping the input
// dynmemrefs so that they can be addressed both by index and by name.
struct OrderedDynMemRefDict;

// Get number of dynamic memrefs in OrderedDynMemRefDict dict.
int numDynMemRefs(OrderedDynMemRefDict *dict);

// Create an ordered dynmemref dictionary.
OrderedDynMemRefDict *createOrderedDynMemRefDict();

// Create a dynmemref with a certain rank.
DynMemRef *createDynMemRef(int rank);

// Get the i-th dynmemref from orderedDict.
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);

// Set the i-th dynmemref in orderedDict to be dynMemRef.
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
                  DynMemRef *dynMemRef);

// Get data pointer from dynMemRef.
void *getData(DynMemRef *dynMemRef);

// Set data pointer for dynMemRef.
void setData(DynMemRef *dynMemRef, void *data);

// Get aligned data pointer from dynMemRef.
void *getAlignedData(DynMemRef *);

// Set aligned data pointer for dynMemRef.
void setAlignedData(DynMemRef *, void *);

// Get ptr to sizes array.
INDEX_TYPE *getSizes(DynMemRef *);

// Get ptr to strides array.
int64_t *getStrides(DynMemRef *);
}

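Callers fill `sizes` and `strides` themselves (the Python runtime below copies them from the NumPy array). A small hypothetical helper (editor-added, not in this commit; include path assumed) that derives row-major strides, in elements, from the sizes array may clarify how the two fields relate for contiguous data:

// Editor sketch: fill a DynMemRef's strides with row-major (C-order) strides,
// expressed in elements, derived from its sizes.
#include "src/runtime/dyn_memref.h" // assumed include path

inline void fillRowMajorStrides(DynMemRef *memRef) {
  int64_t stride = 1;
  for (int i = (int)memRef->rank - 1; i >= 0; i--) {
    memRef->strides[i] = stride;
    stride *= memRef->sizes[i];
  }
}
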
@@ -0,0 +1,52 @@
#include <cstdlib> // malloc
#include <cstring> // memcpy

#include "runtime.hpp"

ExecutionSession::ExecutionSession(std::string sharedLibPath,
                                   std::string entryPointName) {
  _sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
  _entryPointFunc =
      (entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
}

std::vector<py::array>
ExecutionSession::run(std::vector<py::array> inputsPyArray) {
  assert(_entryPointFunc && "entry point not loaded");
  auto *wrappedInput = createOrderedDynMemRefDict();
  int inputIdx = 0;
  for (auto inputPyArray : inputsPyArray) {
    auto *inputDynMemRef = createDynMemRef(inputPyArray.ndim());
    assert((inputPyArray.flags() & py::array::c_style) &&
           "expect contiguous python array");

    if (inputPyArray.writeable()) {
      inputDynMemRef->data = inputPyArray.mutable_data();
      inputDynMemRef->alignedData = inputPyArray.mutable_data();
    } else {
      // If the data is not writable, copy it to a writable buffer.
      auto *copiedData = (float *)malloc(inputPyArray.nbytes());
      memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
      inputDynMemRef->data = copiedData;
      inputDynMemRef->alignedData = copiedData;
    }

    for (int i = 0; i < inputPyArray.ndim(); i++) {
      inputDynMemRef->sizes[i] = inputPyArray.shape(i);
      inputDynMemRef->strides[i] = inputPyArray.strides(i);
    }

    setDynMemRef(wrappedInput, inputIdx++, inputDynMemRef);
  }

  std::vector<py::array> outputPyArrays;
  auto *wrappedOutput = _entryPointFunc(wrappedInput);
  for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
    auto *dynMemRef = getDynMemRef(wrappedOutput, i);
    auto shape = std::vector<int64_t>(dynMemRef->sizes,
                                      dynMemRef->sizes + dynMemRef->rank);
    outputPyArrays.emplace_back(
        py::array(py::dtype("float32"), shape, dynMemRef->data));
  }

  return outputPyArrays;
}

ExecutionSession::~ExecutionSession() { dlclose(_sharedLibraryHandle); }

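The symbol resolved here is the compiled model's generated entry point; the backend test below passes the name "_dyn_entry_point_main_graph". For reference, an editor-added sketch (not part of the commit; names and include path illustrative) of the same loading step in plain C++, without the pybind11 wrapper:

// Editor sketch: resolve a compiled model's entry point directly, mirroring
// what the ExecutionSession constructor does.
#include <dlfcn.h>
#include "src/runtime/dyn_memref.h" // assumed include path

typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);

entryPointFuncType loadEntryPoint(const char *sharedLibPath,
                                  const char *entryPointName) {
  void *handle = dlopen(sharedLibPath, RTLD_LAZY);
  if (!handle)
    return nullptr;
  // Inputs must be wrapped with createOrderedDynMemRefDict()/setDynMemRef()
  // before invoking the returned function, exactly as ExecutionSession::run does.
  return (entryPointFuncType)dlsym(handle, entryPointName);
}
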
@@ -0,0 +1,37 @@
#pragma once

#include <cassert>
#include <string>

#include <dlfcn.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "src/runtime/dyn_memref.h"

namespace py = pybind11;

typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);

class ExecutionSession {
public:
  ExecutionSession(std::string sharedLibPath, std::string entryPointName);

  std::vector<py::array> run(std::vector<py::array> inputsPyArray);

  ~ExecutionSession();

private:
  // Handle to the shared library being loaded.
  void *_sharedLibraryHandle = nullptr;

  // Entry point function.
  entryPointFuncType _entryPointFunc = nullptr;
};

PYBIND11_MODULE(pyruntime, m) {
  py::class_<ExecutionSession>(m, "ExecutionSession")
      .def(py::init<const std::string &, const std::string &>())
      .def("run", &ExecutionSession::run);
}

@@ -0,0 +1,162 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import itertools
import os
import unittest
import onnx.backend.base
import onnx.backend.test

from onnx.backend.base import Device, DeviceType
import onnx.shape_inference
import onnx.version_converter
from typing import Optional, Text, Any, Tuple, Sequence
from onnx import NodeProto, ModelProto, TensorProto
import numpy  # type: ignore
import subprocess
from pyruntime import ExecutionSession

CXX = os.getenv('CXX')
FE = os.getenv('FE')
LLC = os.getenv('LLC')
RT_DIR = os.getenv('RT_DIR')
assert CXX and FE and LLC and RT_DIR, "tools path not set"


class DummyBackend(onnx.backend.base.Backend):
    @classmethod
    def prepare(
            cls,
            model,
            device='CPU',
            **kwargs
    ):
        super(DummyBackend, cls).prepare(model, device, **kwargs)
        # Save model to disk as temp_model.onnx.
        onnx.save(model, "temp_model.onnx")
        # Call the frontend to process temp_model.onnx; LLVM bitcode will be generated.
        subprocess.run([FE, "temp_model.onnx"], stdout=subprocess.PIPE)
        # Call llc to generate an object file from the bitcode.
        subprocess.run([LLC, "-filetype=obj", "model.bc"],
                       stdout=subprocess.PIPE)
        # Generate a shared library from the object file, linking against the C runtime.
        subprocess.run([
            CXX, "-shared", "model.o", "-o", "model.so", "-L" + RT_DIR,
            "-lcruntime"
        ],
                       stdout=subprocess.PIPE)
        return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")

    @classmethod
    def supports_device(cls, device):
        d = Device(device)
        if d.type == DeviceType.CPU:
            return True
        return False


backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__)

# Test directories:
# https://github.com/onnx/onnx/tree/master/onnx/backend/test/data/node

test_to_enable = [
    # Add Op:
    "test_add_cpu",
    "test_add_bcast_cpu",

    # And Op:

    # Sub Op:
    "test_sub_cpu",
    "test_sub_bcast_cpu",
    "test_sub_example_cpu",

    # Cosh Op:
    "test_cosh_cpu",
    "test_cosh_example_cpu",

    # Div Op:
    "test_div_cpu",
    "test_div_bcast_cpu",
    "test_div_example_cpu",

    # Elu Op:
    "test_elu_cpu",
    "test_elu_default_cpu",
    "test_elu_example_cpu",

    # Exp Op:
    "test_exp_cpu",
    "test_exp_example_cpu",

    # Hard Sigmoid Op:
    "test_hardsigmoid_cpu",
    "test_hardsigmoid_default_cpu",
    "test_hardsigmoid_example_cpu",

    # Leaky Relu Op:
    "test_leakyrelu_cpu",
    "test_leakyrelu_default_cpu",
    "test_leakyrelu_example_cpu",

    # Max Op:
    # "test_max_example_cpu", <- error
    "test_max_one_input_cpu",
    # "test_max_two_inputs_cpu", <- error

    # Min Op:
    # "test_min_example_cpu", <- error
    "test_min_one_input_cpu",
    # "test_min_two_inputs_cpu", <- error

    # Mul Op:
    "test_mul_cpu",
    "test_mul_bcast_cpu",
    "test_mul_example_cpu",

    # Relu Op:
    "test_relu_cpu",

    # Selu Op:
    "test_selu_cpu",
    "test_selu_default_cpu",
    "test_selu_example_cpu",

    # Sigmoid Op:
    "test_sigmoid_cpu",
    "test_sigmoid_example_cpu",

    # Sum Op:
    # "test_sum_example_cpu", <- error
    "test_sum_one_input_cpu",
    # "test_sum_two_inputs_cpu", <- error

    # Reciprocal Op:
    # "test_reciprocal_cpu", <- error on shape inference.
    # "test_reciprocal_example_cpu", <- error on shape inference.
]

# Extract the names of all test cases.
import inspect
all_tests = inspect.getmembers(
    backend_test.test_cases["OnnxBackendNodeModelTest"])
all_test_names = list(map(lambda x: x[0], all_tests))
for test_name in test_to_enable:
    assert test_name in all_test_names, "test name {} not found".format(test_name)
    backend_test.include(r"^{}$".format(test_name))


def tearDownModule():
    print()
    print("*" * 40)
    print("A total of {} tests should have run".format(len(test_to_enable)))
    print("*" * 40)


# Import all test cases at global scope to make them visible to Python's unittest.
globals().update(backend_test.test_cases)

if __name__ == '__main__':
    unittest.main()