Enable ONNX Backend Test (#1)

* wip, commit before merging with upstream

* organize API, return wrapped output

* enable onnx backend test

* undo unintentional commit

* fix krnl ops tablegen

* format krnl ops

* reorder fillDynMemRefWithMemRef to be after fillPtrToMemRefWithDynMemRef, better comments

* more onnx backend tests

* ensure that test names refer to existing tests

* improve code readability by shortening type names

* nit

* restore unintentional changes

* more nits

* fix ; -> :

* split runtime implementation into header and body file, add support for data types

* comment on the onnx backend test

* make the comments read better

* do not dump when lowering
Tian Jin 2019-12-22 00:25:02 -05:00
parent 5573cb39fe
commit 685bf23b40
22 changed files with 972 additions and 106 deletions

View File

@@ -26,6 +26,7 @@ add_subdirectory(third_party/pybind11)
set(CMAKE_CXX_STANDARD 14)
add_subdirectory(src/builder)
add_subdirectory(src/compiler)
add_subdirectory(src/runtime)
add_subdirectory(src)
add_subdirectory(test)

View File

@@ -34,7 +34,7 @@ set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include)
set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include)
set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin)
set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/src/compiler/tool/onnf_opt)
set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin)
set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)

View File

@@ -614,7 +614,6 @@ private:
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
// Conv has attributes dilations, kernel_shape, and pads, whose default
// values are determined by the shape of the first argument. However, since
// the shape is unknown at this point, these attributes cannot be generated
// automatically.
@@ -686,7 +685,7 @@ private:
}
void ImportGraph(const onnx::GraphProto &graph,
const std::string &name = "main") {
const std::string &name = "main_graph") {
// create a function for the graph
// TODO:
// * get name and type for the function.
@@ -699,13 +698,18 @@ private:
}
// TODO: import the initializer
auto func_type = builder_.getFunctionType(arg_types, {});
auto main_func =
mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {});
auto &entryBlock = *main_func.addEntryBlock();
auto funcType = builder_.getFunctionType(arg_types, {});
auto mainFunc =
mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
auto entryPoint = mlir::ONNXEntryPointOp::create(
UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(),
/*numOutputs=*/graph.output().size());
auto &entryBlock = *mainFunc.addEntryBlock();
builder_.setInsertionPointToStart(&entryBlock);
module_.push_back(main_func);
module_.push_back(mainFunc);
module_.push_back(entryPoint);
for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) {
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
@@ -728,8 +732,8 @@ private:
builder_.create<mlir::ReturnOp>(UnknownLoc(), ret_vals);
// Update main function signature to reflect types of newly imported
// output tensors.
func_type = builder_.getFunctionType(arg_types, ret_types);
main_func.setType(func_type);
funcType = builder_.getFunctionType(arg_types, ret_types);
mainFunc.setType(funcType);
}
}; // FrontendGenImpl class
} // namespace

View File

@@ -41,7 +41,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
* @return MLIR::module generated for the ONNX model.
*/
void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext &context,
mlir::OwningModuleRef &module);
/*!
* TODO: Import models into other extension dialects that cover the

View File

@@ -0,0 +1,5 @@
add_library(DLCAnalysis STATIC
extract_integer_set.cpp)
target_include_directories(DLCAnalysis PRIVATE ${DLC_SRC_ROOT})
target_include_directories(DLCAnalysis PRIVATE ${DLC_BIN_ROOT})

View File

@@ -364,6 +364,14 @@ ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
return success();
}
void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state,
SymbolRefAttr funcAttr, IntegerAttr numInputs,
IntegerAttr numOutputs) {
state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr);
state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs);
state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs);
}
#define GET_OP_CLASSES
#include "src/compiler/krnl.cpp.inc"
} // namespace mlir

View File

@@ -11,6 +11,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"

View File

@@ -8,35 +8,27 @@
include "mlir/IR/OpBase.td"
def Krnl_Dialect : Dialect {
let name = "krnl";
let cppNamespace = "";
}
// Require regions to have krnl.terminate terminator operation.
def ImplicitKrnlTerminator : SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
let summary = "define_loops operation";
let description = [{
The "krnl.define_loops" operation is used to define input loops,
that is, the for loops appearing in the input program that we
intend to optimize.
}];
let arguments = (ins);
let results = (outs Variadic<AnyType>);
let skipDefaultBuilders = 1;
let builders = [ OpBuilder<"Builder *builder, OperationState &result,"
"int64_t num_loops"> ];
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
@@ -46,31 +38,25 @@ def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
// Helper function to extract the number of loops being defined.
int64_t getNumLoops() {
auto num_loops = getAttrOfType<IntegerAttr>(getNumLoopsAttrName())
.getValue()
.getSExtValue();
return num_loops;
}
}];
}
def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
let summary = "optimize_loops operation";
let description = [{
The "krnl.optimize_loops" operation is essentially a cosmetic operation
which exists to encapsulate a region where loops are being scheduled/optimized.
which exists to encapsulate a region where loops are being scheduled /
optimized.
The optimized loops are returned at the end of the
region associated with the krnl.optimize_loops operation.
For example:
TBD once we have actual schedule intrinsics.
The optimized loops are returned at the end of the region associated with
the krnl.optimize_loops operation.
For example : TBD once we have actual schedule intrinsics.
}];
let arguments = (ins Variadic<AnyType>);
@@ -79,10 +65,8 @@ def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
let skipDefaultBuilders = 1;
let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
"int timestamp_space_rank"> ];
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
@@ -91,7 +75,6 @@ def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
let summary = "iterate operation";
let description = [{
The "krnl.iterate" operation is conceptually equivalent to a nested for loops.
For instance, say we have the following two
@@ -113,15 +96,10 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
}];
let arguments = (ins Variadic<AnyType>);
let regions = (region SizedRegion<1>:$bodyRegion);
let skipDefaultBuilders = 1;
let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
"KrnlIterateOperandPack operandPack"> ];
let extraClassDeclaration = [{
// In krnl.iterate operation, operands are stored as such
@@ -134,8 +112,7 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
int64_t getNumOptimizedLoops() {
auto num_optimized_loops =
getAttrOfType<IntegerAttr>(getNumOptimizedLoopsAttrName())
.getValue()
.getSExtValue();
return num_optimized_loops;
@@ -182,11 +159,30 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
let verifier = ?;
}
def KrnlEntryPointOp : Op<Krnl_Dialect, "entry_point"> {
let summary = "Indicate ONNX entry point";
let description = [{The "krnl.entry_point" operation indicates the main entry
point of an ONNX model.}];
let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
"SymbolRefAttr funcAttr, IntegerAttr numInputs, "
"IntegerAttr numOutputs"> ];
let extraClassDeclaration = [{
static StringRef getEntryPointFuncAttrName() { return "func"; }
static StringRef getNumInputsAttrName() { return "numInputs"; }
static StringRef getNumOutputsAttrName() { return "numOutputs"; }
}];
// No custom parsing/printing form.
let parser = ?;
let printer = ?;
}
def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
let summary = "Krnl memcpy operation";
let description = [{
In the KRNL dialect the reshape op
doesn't generate a new memory entry and treats a reshape like a cast.
}];
let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);

View File

@@ -58,6 +58,27 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
include "dialect/onnx/onnxop.inc"
// Indicate entry point functions of ONNX graph.
def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
let summary = "Indicate ONNX entry point";
let description = [{
The "onnx.EntryPoint" function indicates the main entry point of ONNX model.
}];
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
FuncOp function, int numInputs, int numOutputs}]>];
let extraClassDeclaration = [{
static ONNXEntryPointOp create(Location location, FuncOp& func,
int numInputs, int numOutputs);
static StringRef getEntryPointFuncAttrName() { return "func"; }
static StringRef getNumInputsAttrName() { return "numInputs"; }
static StringRef getNumOutputsAttrName() { return "numOutputs"; }
}];
}
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation";

View File

@@ -7,7 +7,6 @@
// This file defines ONNX operations in the MLIR operation set.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
@@ -38,6 +37,28 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
>();
}
void ONNXEntryPointOp::build(mlir::Builder *builder,
mlir::OperationState &state, mlir::FuncOp function,
int numInputs, int numOutputs) {
state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
builder->getSymbolRefAttr(function));
state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
builder->getI32IntegerAttr(numInputs));
state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
builder->getI32IntegerAttr(numOutputs));
}
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
mlir::FuncOp &func, int numInputs,
int numOutputs) {
mlir::OperationState state(location, "onnx.EntryPoint");
Builder builder(location->getContext());
mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
Operation *op = mlir::Operation::create(state);
auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
return onnxEntryOp;
}
//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//

View File

@@ -10,6 +10,7 @@
#pragma once
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"

View File

@@ -8,7 +8,6 @@
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//
#include <map>
#include "mlir/Dialect/AffineOps/AffineOps.h"
@@ -884,6 +883,27 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
}
};
//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//
class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public:
using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
op,
op.getAttrOfType<SymbolRefAttr>(
ONNXEntryPointOp::getEntryPointFuncAttrName()),
op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
op.getAttrOfType<IntegerAttr>(
ONNXEntryPointOp::getNumOutputsAttrName()));
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
@@ -985,7 +1005,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering>(&getContext());
ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`

View File

@@ -143,15 +143,17 @@ void KrnlToAffineLoweringPass::runOnFunction() {
// We expect IR to be free of Krnl Dialect Ops.
target.addIllegalDialect<KrnlOpsDialect>();
target.addLegalOp<KrnlMemcpyOp>();
target.addLegalOp<KrnlEntryPointOp>();
OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
&getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns)))
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
signalPassFailure();
}
}
} // namespace

View File

@@ -4,7 +4,6 @@
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/Sequence.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
@@ -15,6 +14,7 @@
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/pass/passes.hpp"
@@ -23,8 +23,26 @@ using namespace mlir;
namespace {
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
ModuleOp module,
mlir::LLVM::LLVMType funcType,
PatternRewriter &rewriter) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
auto symbolRef = SymbolRefAttr::get(funcName, context);
assert(symbolRef.getType() == funcType && "wrong symbol type");
return symbolRef;
}
// Insert the function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
return SymbolRefAttr::get(funcName, context);
}
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlMemcpyOpLowering
//===----------------------------------------------------------------------===//
class KrnlMemcpyOpLowering : public ConversionPattern {
@@ -32,7 +50,8 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
explicit KrnlMemcpyOpLowering(MLIRContext *context)
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
auto loc = op->getLoc();
@@ -65,8 +84,8 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
// Memcpy call
rewriter.create<CallOp>(
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value *>(
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
@@ -78,7 +97,8 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
/// Return a symbol reference to the memcpy function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
ModuleOp module,
LLVM::LLVMDialect *llvmDialect) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
@@ -87,18 +107,354 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
llvmVoidTy,
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
false);
// Insert the memcpy function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(),
"llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
}
};
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlEntryPointOp
//===----------------------------------------------------------------------===//
class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
public:
using OpRewritePattern<KrnlEntryPointOp>::OpRewritePattern;
enum class API {
CREATE_ORDERED_DYN_MEM_REF_DICT,
CREATE_DYN_MEM_REF,
GET_DYN_MEM_REF,
SET_DYN_MEM_REF,
GET_DATA,
SET_DATA,
GET_SIZES,
GET_STRIDES,
};
struct ApiSpec {
API id;
std::string name;
FlatSymbolRefAttr symbolRef;
LLVM::LLVMType outputTy;
SmallVector<LLVM::LLVMType, 4> inputTys;
ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy,
ArrayRef<LLVM::LLVMType> inputTys)
: id(id), name(name), outputTy(outputTy),
inputTys(inputTys.begin(), inputTys.end()) {}
LLVM::LLVMType funcTy() {
return LLVM::LLVMType::getFunctionTy(outputTy, inputTys,
/*isVarArg=*/false);
}
};
PatternMatchResult matchAndRewrite(KrnlEntryPointOp op,
PatternRewriter &rewriter) const override {
auto *llvmDialect =
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
auto module = op.getParentOfType<ModuleOp>();
auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect);
auto loc = op.getLoc();
auto numOutputs =
op.getAttrOfType<IntegerAttr>(KrnlEntryPointOp::getNumOutputsAttrName())
.getInt();
using LLVMType = LLVM::LLVMType;
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
// Rewrite the Krnl entry point operation to an LLVM function with a
// "dynamic" signature: the signature stays the same no matter what the
// model's input/output schema looks like. The function takes an opaque
// pointer to a data structure holding the input dynamic memrefs, and
// similarly returns an opaque pointer to a data structure wrapping the
// output memrefs.
auto staticEntryPointFuncName =
op.getAttrOfType<SymbolRefAttr>(
KrnlEntryPointOp::getEntryPointFuncAttrName())
.getLeafReference();
auto dynEntryPointName = "_dyn_entry_point_" + staticEntryPointFuncName;
assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr &&
"dynamic entry point name is not unique");
rewriter.eraseOp(op);
auto dynEntryPointFuncTy =
LLVMType::getFunctionTy(opaquePtrTy, {opaquePtrTy}, false);
auto dynamicEntryPointFunc = rewriter.create<LLVM::LLVMFuncOp>(
loc, dynEntryPointName.str(), dynEntryPointFuncTy);
auto &entryPointEntryBlock =
createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc);
rewriter.setInsertionPointToStart(&entryPointEntryBlock);
// Based on the static entry point type signature, unpack dynamic memory
// refs to corresponding static memory refs.
auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
assert(staticEntryPointFunc &&
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
"entry point func must exist and be an llvm func op");
auto staticEntryPointTy = dyn_cast<LLVM::LLVMFuncOp>(staticEntryPointFunc)
.getType()
.dyn_cast<LLVMType>();
// Retrieve dynamic mem refs from wrapped input, and convert every one of
// them to static mem refs.
SmallVector<Value *, 4> staticInputs;
auto wrappedInput = entryPointEntryBlock.getArgument(0);
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
// Call API function to retrieve the i-th dynamic memref.
auto idxVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(i));
auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
{wrappedInput, idxVal});
// Allocate stack space for a (static) memref corresponding to the i-th
// input of the inference function.
auto memRefPtrTy = staticEntryPointTy.getFunctionParamType(i);
auto memRefTy = memRefPtrTy.getPointerElementTy();
auto one = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(1));
Value *ptrToMemRef =
rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
/*alignment=*/0);
// Fill in the memref underlying ptrToMemRef with information extracted
// from dynMemRef.
fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc,
apiRegistry, llvmDialect);
// ptrToMemRef will be an input to main computation graph function.
staticInputs.emplace_back(ptrToMemRef);
}
// If more than one output exists, the struct becomes a nested struct and
// the unpacking logic is more involved, so it is not supported for now.
assert(numOutputs == 1 && "only support 1 output tensor now.");
// Call static entry point with the memref ptrs created, and get output.
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
loc, staticEntryPointTy.getFunctionResultType(),
rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
// Create wrapped output.
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
// Get the first memref returned, convert to a dynamic memref and store
// it in the wrapped output.
auto outMemRef = outputMemRefs.getResult(0);
auto outMemRefTy = outMemRef->getType().dyn_cast<LLVMType>();
auto outMemRefRank =
outMemRefTy.getStructElementType(3).getArrayNumElements();
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc,
apiRegistry, llvmDialect);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(0));
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
{wrappedOutput, zero, outDynMemRef});
// Return wrapped output.
rewriter.create<LLVM::ReturnOp>(loc,
SmallVector<Value *, 1>({wrappedOutput}));
return matchSuccess();
}
private:
using ApiRegistry = std::map<API, ApiSpec>;
ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter,
LLVM::LLVMDialect *llvmDialect) const {
using LLVMType = LLVM::LLVMType;
auto voidTy = LLVMType::getVoidTy(llvmDialect);
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
auto int64Ty = LLVMType::getInt64Ty(llvmDialect);
auto int64PtrTy = int64Ty.getPointerTo();
// Declare API type as an enum value, its string name and an LLVM Type
// specifying its signature.
// clang-format off
std::vector<ApiSpec> apiSpecs = {
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}),
ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}),
ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy})
};
// clang-format on
// Declare APIs in the current module and build an API registry mapping api
// identities to a symbol reference to the API function.
ApiRegistry registry;
for (auto &apiSpec : apiSpecs) {
apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module,
apiSpec.funcTy(), rewriter);
registry.emplace(apiSpec.id, apiSpec);
}
return registry;
}
// Call a registered API; return the single returned SSA value if the API
// produces exactly one result, otherwise return nullptr.
Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value *> params) const {
auto returnVals = rewriter.create<LLVM::CallOp>(
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
ArrayRef<Value *>(params));
if (returnVals.getNumResults() == 1)
return returnVals.getResult(0);
return nullptr;
}
// Helper function to insert an entry block into an LLVM function.
// (TODO): upstream this to MLIR.
Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType,
LLVM::LLVMFuncOp &dynamicEntryPointFunc) const {
// Add entry block:
auto *entryPointEntryBlock = new Block();
dynamicEntryPointFunc.push_back(entryPointEntryBlock);
llvm::SmallVector<Type, 4> argTypes;
for (size_t i = 0; i < dynEntryPointFuncType.getFunctionNumParams(); i++)
argTypes.emplace_back(dynEntryPointFuncType.getFunctionParamType(i));
entryPointEntryBlock->addArguments(argTypes);
return *entryPointEntryBlock;
}
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
PatternRewriter &rewriter,
const Location &loc,
const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const {
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto memRefTy = memRefPtrTy.getPointerElementTy();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
Value *memRef =
rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef);
// Set dataPtr and alignedDataPtr.
auto dataPtr =
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef});
dataPtr = rewriter.create<LLVM::BitcastOp>(
loc, memRefTy.getStructElementType(0), dataPtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dataPtr,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dataPtr,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
// Use zero offset now.
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int64Ty, rewriter.getI64IntegerAttr(0));
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, zero,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
// Get rank, sizes array ptr and strides array ptr.
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef});
auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef});
for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
loc, int64Ty, rewriter.getI64IntegerAttr(i));
// Insert size of the dimension.
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx}));
auto dimSize =
rewriter.create<LLVM::LoadOp>(loc, int64Ty, dimSizePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dimSize,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
// Insert stride of the dimension.
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), stridesArrayPtr,
ArrayRef<Value *>({dimIdx}));
auto dimStride =
rewriter.create<LLVM::LoadOp>(loc, int64Ty, dimStridePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dimStride,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
}
rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef);
}
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
PatternRewriter &rewriter, const Location &loc,
const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const {
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
// Extract the data pointer, and record it in dynamic mem ref created.
Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, outMemRefTy.getStructElementType(0), &outMemRef,
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
{&outDynMemRef, outMemRefDataPtr});
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef});
auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef});
for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
loc, int64Ty, rewriter.getI64IntegerAttr(i));
// Transfer size of dimension from memref to dynamic memref.
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, &outMemRef,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
// Transfer stride of dimension from memref to dynamic memref.
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, &outMemRef,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), stridesArrayPtr,
ArrayRef<Value *>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
}
}
};
} // end namespace
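The struct element indices 0 through 4 used by the two fill helpers above follow MLIR's standard memref-to-LLVM descriptor convention. As a point of reference, here is a sketch (an assumption about the lowered layout, not code from this commit) of the descriptor for a rank-2 memref of float:

// Assumed LLVM-lowered layout of a rank-2 memref<?x?xf32> descriptor.
struct MemRefDescriptorRank2 {
  float *allocatedPtr; // element 0: data pointer (GET_DATA / SET_DATA)
  float *alignedPtr;   // element 1: aligned data pointer
  int64_t offset;      // element 2: offset into the underlying buffer
  int64_t sizes[2];    // element 3: per-dimension sizes (GET_SIZES)
  int64_t strides[2];  // element 4: per-dimension strides (GET_STRIDES)
};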
//===----------------------------------------------------------------------===//
@@ -128,12 +484,13 @@ void KrnlToLLVMLoweringPass::runOnModule() {
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering>(&getContext());
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
auto module = getModule();
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
if (failed(
applyFullConversion(getModule(), target, patterns, &typeConverter)))
signalPassFailure();
}
@@ -142,5 +499,5 @@ std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
return std::make_unique<KrnlToLLVMLoweringPass>();
}
static PassRegistration<KrnlToLLVMLoweringPass>
pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");

View File

@@ -7,10 +7,7 @@
//===----------------------------------------------------------------------===//
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <random>
#include <tuple>
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Support/CommandLine.h"
@@ -24,9 +21,7 @@
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
#include "src/compiler/pass/passes.hpp"
#include "mlir/Analysis/Verifier.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h"

View File

@@ -0,0 +1,17 @@
add_library(cruntime
dyn_memref.cpp
dyn_memref.h
data_type.h)
target_include_directories(cruntime
PRIVATE ${DLC_SRC_ROOT} ${DLC_BIN_ROOT})
pybind11_add_module(pyruntime
dyn_memref.cpp
dyn_memref.h
runtime.cpp
runtime.hpp)
target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
target_include_directories(pyruntime
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT})

src/runtime/data_type.h Normal file
View File

@@ -0,0 +1,30 @@
enum DYN_MEMREF_DATA_TYPE {
UNDEFINED = 0,
// Basic types.
FLOAT = 1,  // float
UINT8 = 2,  // uint8_t
INT8 = 3,   // int8_t
UINT16 = 4, // uint16_t
INT16 = 5,  // int16_t
INT32 = 6,  // int32_t
INT64 = 7,  // int64_t
STRING = 8, // string
BOOL = 9,   // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10,
DOUBLE = 11,
UINT32 = 12,
UINT64 = 13,
COMPLEX64 = 14,  // complex with float32 real and imaginary components
COMPLEX128 = 15, // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16,
// Future extensions go here.
};
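The enum is not yet consumed anywhere in this commit. As an illustration of how a runtime might use it, here is a hypothetical helper (the function name and size mapping are assumptions based on the C types noted in the comments, not part of this commit) returning the element size in bytes for each tag:

#include <cstddef>

// Hypothetical helper, not part of this commit: element size in bytes.
inline size_t dynMemRefDataTypeSize(DYN_MEMREF_DATA_TYPE dtype) {
  switch (dtype) {
  case UINT8: case INT8: case BOOL: return 1;
  case UINT16: case INT16: case FLOAT16: case BFLOAT16: return 2;
  case FLOAT: case INT32: case UINT32: return 4;
  case DOUBLE: case INT64: case UINT64: case COMPLEX64: return 8;
  case COMPLEX128: return 16;
  default: return 0; // UNDEFINED and STRING have no fixed element size.
  }
}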

View File

@@ -0,0 +1,74 @@
#include <cassert>
#include <map>
#include <string>
#include <vector>
#include "dyn_memref.h"
DynMemRef::DynMemRef(int _rank) {
rank = _rank;
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
strides = (int64_t *)malloc(rank * sizeof(int64_t));
}
// An ordered dynamic MemRef dictionary.
// The goal is to support accessing dynamic memory ref by name and by index.
// Currently, only accessing by index is supported.
struct OrderedDynMemRefDict {
std::map<std::string, DynMemRef *> tensorDict;
std::vector<std::string> orderedNames;
};
int numDynMemRefs(OrderedDynMemRefDict *dict) {
return dict->orderedNames.size();
}
OrderedDynMemRefDict *createOrderedDynMemRefDict() {
return new OrderedDynMemRefDict();
}
DynMemRef *createDynMemRef(int rank) { return new DynMemRef(rank); }
DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
}
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
DynMemRef *tensor) {
if (tensorDict->orderedNames.size() <= static_cast<size_t>(idx))
tensorDict->orderedNames.resize(idx + 1);
// The dynamic memref is essentially anonymous, since we are storing it by
// indexed position.
// TODO: could use random strings as names to reduce the chance of collision.
auto unique_name = std::to_string(idx);
assert(tensorDict->tensorDict.count(unique_name) == 0 &&
"duplicate dynamic mem ref name");
tensorDict->orderedNames[idx] = unique_name;
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
}
void *getData(DynMemRef *dynMemRef) { return dynMemRef->data; }
void setData(DynMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
void *getAlignedData(DynMemRef *dynMemRef) { return dynMemRef->alignedData; }
void setAlignedData(DynMemRef *dynMemRef, void *dataPtr) {
dynMemRef->alignedData = dataPtr;
}
INDEX_TYPE *getSizes(DynMemRef *dynMemRef) { return dynMemRef->sizes; }
void setSizes(DynMemRef *dynMemRef, INDEX_TYPE *sizes) {
for (int i = 0; i < dynMemRef->rank; i++)
dynMemRef->sizes[i] = sizes[i];
}
int64_t *getStrides(DynMemRef *dynMemRef) { return dynMemRef->strides; }
void setStrides(DynMemRef *dynMemRef, int64_t *strides) {
for (int i = 0; i < dynMemRef->rank; i++)
dynMemRef->strides[i] = strides[i];
}

src/runtime/dyn_memref.h Normal file
View File

@@ -0,0 +1,61 @@
#pragma once
#include <cstdint>
typedef int64_t INDEX_TYPE;
// This is a dynamic version of memref.
// The same struct can be used to represent memrefs of
// all ranks and type combinations.
struct DynMemRef {
void *data;
void *alignedData;
INDEX_TYPE offset;
unsigned int rank;
INDEX_TYPE *sizes;
int64_t *strides;
DynMemRef(int _rank);
};
extern "C" {
// Ordered DynMemRef Dictionary is a data structure for wrapping the input
// dynmemrefs so that they can be addressed both by index and by name.
struct OrderedDynMemRefDict;
// Get number of dynamic memrefs in OrderedDynMemRefDict dict.
int numDynMemRefs(OrderedDynMemRefDict *dict);
// Create an ordered dynmemref dictionary.
OrderedDynMemRefDict *createOrderedDynMemRefDict();
// Create a dynmemref with a certain rank.
DynMemRef *createDynMemRef(int rank);
// Get the i-th dynmemref from orderedDict.
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
// Set the i-th dynmemref in orderedDict to be dynMemRef.
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
DynMemRef *dynMemRef);
// Get data pointer from dynMemRef.
void *getData(DynMemRef *dynMemRef);
// Set data pointer for dynMemRef.
void setData(DynMemRef *dynMemRef, void *data);
// Get aligned data pointer from dynMemRef.
void *getAlignedData(DynMemRef *);
// Set aligned data pointer for dynMemRef.
void setAlignedData(DynMemRef *, void *);
// Get ptr to sizes array.
INDEX_TYPE *getSizes(DynMemRef *);
// Get ptr to strides array.
int64_t *getStrides(DynMemRef *);
}
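To make this ABI concrete: a native client would mirror what the Python runtime below does, i.e. dlopen the compiled model, look up the generated dynamic entry point, and exchange tensors through OrderedDynMemRefDict. A minimal sketch, assuming a model compiled to ./model.so with one 2x3 float32 input (the path and shape are placeholders; error handling and cleanup elided):

#include <dlfcn.h>
#include <cstdint>
#include "dyn_memref.h"

// Signature of the generated dynamic entry point (same typedef as in
// runtime.hpp below).
typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);

int main() {
  void *handle = dlopen("./model.so", RTLD_LAZY);
  auto entryPoint =
      (entryPointFuncType)dlsym(handle, "_dyn_entry_point_main_graph");

  // Wrap one 2x3 float input in a dynamic memref.
  float input[6] = {0};
  DynMemRef *in = createDynMemRef(/*rank=*/2);
  setData(in, input);
  setAlignedData(in, input);
  INDEX_TYPE sizes[2] = {2, 3};
  int64_t strides[2] = {3, 1}; // row-major element strides
  setSizes(in, sizes);
  setStrides(in, strides);

  OrderedDynMemRefDict *inputs = createOrderedDynMemRefDict();
  setDynMemRef(inputs, 0, in);

  // Run inference; the outputs come back wrapped the same way.
  OrderedDynMemRefDict *outputs = entryPoint(inputs);
  float *outData = (float *)getData(getDynMemRef(outputs, 0));
  (void)outData; // consume results here
  dlclose(handle);
  return 0;
}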

src/runtime/runtime.cpp Normal file
View File

@@ -0,0 +1,52 @@
#include "runtime.hpp"
ExecutionSession::ExecutionSession(std::string sharedLibPath,
std::string entryPointName) {
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
assert(_sharedLibraryHandle && "failed to load shared library");
_entryPointFunc =
(entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
assert(_entryPointFunc && "failed to load entry point function");
}
std::vector<py::array>
ExecutionSession::run(std::vector<py::array> inputsPyArray) {
assert(_entryPointFunc && "entry point not loaded");
auto *wrappedInput = createOrderedDynMemRefDict();
int inputIdx = 0;
for (auto inputPyArray : inputsPyArray) {
auto *inputDynMemRef = createDynMemRef(inputPyArray.ndim());
assert((inputPyArray.flags() & py::array::c_style) &&
"expect contiguous python array");
if (inputPyArray.writeable()) {
inputDynMemRef->data = inputPyArray.mutable_data();
inputDynMemRef->alignedData = inputPyArray.mutable_data();
} else {
// If data is not writable, copy them to a writable buffer.
auto *copiedData = (float *)malloc(inputPyArray.nbytes());
memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
inputDynMemRef->data = copiedData;
inputDynMemRef->alignedData = copiedData;
}
for (int i = 0; i < inputPyArray.ndim(); i++) {
inputDynMemRef->sizes[i] = inputPyArray.shape(i);
inputDynMemRef->strides[i] = inputPyArray.strides(i);
}
setDynMemRef(wrappedInput, inputIdx++, inputDynMemRef);
}
std::vector<py::array> outputPyArrays;
auto *wrappedOutput = _entryPointFunc(wrappedInput);
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
auto *dynMemRef = getDynMemRef(wrappedOutput, i);
auto shape = std::vector<int64_t>(dynMemRef->sizes,
dynMemRef->sizes + dynMemRef->rank);
outputPyArrays.emplace_back(
py::array(py::dtype("float32"), shape, dynMemRef->data));
}
return outputPyArrays;
}
ExecutionSession::~ExecutionSession() { dlclose(_sharedLibraryHandle); }

src/runtime/runtime.hpp Normal file
View File

@@ -0,0 +1,37 @@
#pragma once
#include <cassert>
#include <string>
#include <dlfcn.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "src/runtime/dyn_memref.h"
namespace py = pybind11;
typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);
class ExecutionSession {
public:
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
std::vector<py::array> run(std::vector<py::array> inputsPyArray);
~ExecutionSession();
private:
// Handle to the shared library file being loaded.
void *_sharedLibraryHandle = nullptr;
// Entry point function.
entryPointFuncType _entryPointFunc = nullptr;
};
PYBIND11_MODULE(pyruntime, m) {
py::class_<ExecutionSession>(m, "ExecutionSession")
.def(py::init<const std::string &, const std::string &>())
.def("run", &ExecutionSession::run);
}
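From Python, the module is driven the way the backend test below does; a minimal usage sketch (the shared-library path is a placeholder, and the entry point name is the one the importer generates for "main_graph"):

import numpy as np
from pyruntime import ExecutionSession

session = ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
outputs = session.run([np.ones((2, 3), dtype=np.float32)])
print(outputs[0].shape)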

test/onnx_backend_test.py Normal file
View File

@@ -0,0 +1,162 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import itertools
import os
import unittest
import onnx.backend.base
import onnx.backend.test
from onnx.backend.base import Device, DeviceType
import onnx.shape_inference
import onnx.version_converter
from typing import Optional, Text, Any, Tuple, Sequence
from onnx import NodeProto, ModelProto, TensorProto
import numpy # type: ignore
import subprocess
from pyruntime import ExecutionSession
CXX = os.getenv('CXX')
FE = os.getenv('FE')
LLC = os.getenv('LLC')
RT_DIR = os.getenv('RT_DIR')
assert CXX and FE and LLC and RT_DIR, "tools path not set"
class DummyBackend(onnx.backend.base.Backend):
@classmethod
def prepare(
cls,
model,
device='CPU',
**kwargs
):
super(DummyBackend, cls).prepare(model, device, **kwargs)
# Save model to disk as temp_model.onnx.
onnx.save(model, "temp_model.onnx")
# Call the frontend to process temp_model.onnx; bitcode (model.bc) will be generated.
subprocess.run([FE, "temp_model.onnx"], stdout=subprocess.PIPE)
# Call llc to generate object file from bitcode.
subprocess.run([LLC, "-filetype=obj", "model.bc"],
stdout=subprocess.PIPE)
# Generate shared library from object file, linking with c runtime.
subprocess.run([
CXX, "-shared", "model.o", "-o", "model.so", "-L" + RT_DIR,
"-lcruntime"
],
stdout=subprocess.PIPE)
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
@classmethod
def supports_device(cls, device):
d = Device(device)
if d.type == DeviceType.CPU:
return True
return False
backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__)
# Test directories:
# https://github.com/onnx/onnx/tree/master/onnx/backend/test/data/node
test_to_enable = [
# Add Op:
"test_add_cpu",
"test_add_bcast_cpu",
# And Op:
# Sub Op:
"test_sub_cpu",
"test_sub_bcast_cpu",
"test_sub_example_cpu",
# Cosh Op:
"test_cosh_cpu",
"test_cosh_example_cpu",
# Div Op:
"test_div_cpu",
"test_div_bcast_cpu",
"test_div_example_cpu",
# Elu Op:
"test_elu_cpu",
"test_elu_default_cpu",
"test_elu_example_cpu",
# Exp Op:
"test_exp_cpu",
"test_exp_example_cpu",
# Hard Sigmoid Op:
"test_hardsigmoid_cpu",
"test_hardsigmoid_default_cpu",
"test_hardsigmoid_example_cpu",
# Leaky Relu Op:
"test_leakyrelu_cpu",
"test_leakyrelu_default_cpu",
"test_leakyrelu_example_cpu",
# Max Op:
# "test_max_example_cpu", <- error
"test_max_one_input_cpu",
# "test_max_two_inputs_cpu", <- error
# Min Op:
# "test_min_example_cpu", <- error
"test_min_one_input_cpu",
# "test_min_two_inputs_cpu", <- error
# Mul Op:
"test_mul_cpu",
"test_mul_bcast_cpu",
"test_mul_example_cpu",
# Relu Op:
"test_relu_cpu",
# Selu Op:
"test_selu_cpu",
"test_selu_default_cpu",
"test_selu_example_cpu",
# Sigmoid Op:
"test_sigmoid_cpu",
"test_sigmoid_example_cpu",
# Sum Op:
#"test_sum_example_cpu", <- error
"test_sum_one_input_cpu",
#"test_sum_two_inputs_cpu", <- error
# Reciprocal Op:
#"test_reciprocal_cpu", <- error on shape inference.
#"test_reciprocal_example_cpu", <- error on shape inference.
]
# Extract the names of all test cases.
import inspect
all_tests = inspect.getmembers(
backend_test.test_cases["OnnxBackendNodeModelTest"])
all_test_names = list(map(lambda x: x[0], all_tests))
for test_name in test_to_enable:
assert test_name in all_test_names, "test name {} not found".format(test_name)
backend_test.include(r"^{}$".format(test_name))
def tearDownModule():
print()
print("*" * 40)
print("A total of {} tests should have run".format(len(test_to_enable)))
print("*" * 40)
# Import all test cases at global scope to make them visible to Python's unittest.
globals().update(backend_test.test_cases)
if __name__ == '__main__':
unittest.main()
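For reference, the suite is driven through the four environment variables checked at the top of the file; a hypothetical invocation (all paths are placeholders, not taken from this commit):

CXX=c++ FE=path/to/onnf LLC=path/to/llc RT_DIR=path/to/build/src/runtime \
  python test/onnx_backend_test.py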