From 685bf23b40bd55eb77452d69d564293af3478dfe Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Sun, 22 Dec 2019 00:25:02 -0500 Subject: [PATCH] Enable ONNX Backend Test (#1) * wip, commit before merging with upstream * organize API, return wrapped output * enable onnx backend test * undo unintentional commit * fix krnl ops tablegen * format krnl ops * reorder fillDynMemRefWithMemRef to be after fillPtrToMemRefWithDynMemRef, better comments * more onnx backend tests * ensure that test names refer to existing tests * improve code readability by shortening type names * nit * restore unintentional changes * more nits * fix ; -> : * split runtime implementation into header and body file, add support for data types * comment on the onnx backend test * make the comments read better * do not dump when lowering --- CMakeLists.txt | 1 + MLIR.cmake | 2 +- src/builder/frontend_dialect_transformer.cpp | 22 +- src/builder/frontend_dialect_transformer.hpp | 5 +- src/compiler/analysis/CMakeLists.txt | 5 + src/compiler/dialect/krnl/krnl_ops.cpp | 8 + src/compiler/dialect/krnl/krnl_ops.hpp | 1 + src/compiler/dialect/krnl/krnl_ops.td | 106 +++-- src/compiler/dialect/onnx/onnx.td | 21 + src/compiler/dialect/onnx/onnx_ops.cpp | 23 +- src/compiler/dialect/onnx/onnx_ops.hpp | 1 + src/compiler/pass/lower_frontend_to_krnl.cpp | 24 +- src/compiler/transform/lower_krnl.cpp | 4 +- src/compiler/transform/lower_to_llvm.cpp | 417 +++++++++++++++++-- src/main.cpp | 5 - src/runtime/CMakeLists.txt | 17 + src/runtime/data_type.h | 30 ++ src/runtime/dyn_memref.cpp | 74 ++++ src/runtime/dyn_memref.h | 61 +++ src/runtime/runtime.cpp | 52 +++ src/runtime/runtime.hpp | 37 ++ test/onnx_backend_test.py | 162 +++++++ 22 files changed, 972 insertions(+), 106 deletions(-) create mode 100644 src/compiler/analysis/CMakeLists.txt create mode 100644 src/runtime/CMakeLists.txt create mode 100644 src/runtime/data_type.h create mode 100644 src/runtime/dyn_memref.cpp create mode 100644 src/runtime/dyn_memref.h create 
mode 100644 src/runtime/runtime.cpp create mode 100644 src/runtime/runtime.hpp create mode 100644 test/onnx_backend_test.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f4c5fc..7cc613d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,7 @@ add_subdirectory(third_party/pybind11) set(CMAKE_CXX_STANDARD 14) add_subdirectory(src/builder) add_subdirectory(src/compiler) +add_subdirectory(src/runtime) add_subdirectory(src) add_subdirectory(test) diff --git a/MLIR.cmake b/MLIR.cmake index 68b4447..7200493 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -34,7 +34,7 @@ set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include) set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include) set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin) -set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/src/compiler/tool/onnf_opt) +set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin) set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir) set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 3f5d8d9..05c079a 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -614,7 +614,6 @@ private: onnx::NodeProto node, int nIn, int nOut, std::initializer_list> attrs) { - // Conv has attribute dilations, kernel_shape, pads, the default value of // which is determined by the shape of first argument. However, since the // shape is unknown now, these attributes can be not generated auto @@ -686,7 +685,7 @@ private: } void ImportGraph(const onnx::GraphProto &graph, - const std::string &name = "main") { + const std::string &name = "main_graph") { // create a function for the graph // TODO: // * get name and type for the function. 
@@ -699,13 +698,18 @@ private: } // TODO: import the initializer - auto func_type = builder_.getFunctionType(arg_types, {}); - auto main_func = - mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {}); - auto &entryBlock = *main_func.addEntryBlock(); + auto funcType = builder_.getFunctionType(arg_types, {}); + auto mainFunc = + mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {}); + auto entryPoint = mlir::ONNXEntryPointOp::create( + UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(), + /*numOutputs=*/graph.output().size()); + auto &entryBlock = *mainFunc.addEntryBlock(); builder_.setInsertionPointToStart(&entryBlock); - module_.push_back(main_func); + + module_.push_back(mainFunc); + module_.push_back(entryPoint); for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) { ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it)); @@ -728,8 +732,8 @@ private: builder_.create(UnknownLoc(), ret_vals); // Update main function signature to reflect types of newly imported // output tensors. - func_type = builder_.getFunctionType(arg_types, ret_types); - main_func.setType(func_type); + funcType = builder_.getFunctionType(arg_types, ret_types); + mainFunc.setType(funcType); } }; // FrontendGenImpl class } // namespace diff --git a/src/builder/frontend_dialect_transformer.hpp b/src/builder/frontend_dialect_transformer.hpp index f12512c..8544087 100644 --- a/src/builder/frontend_dialect_transformer.hpp +++ b/src/builder/frontend_dialect_transformer.hpp @@ -21,7 +21,7 @@ namespace mlir { class MLIRContext; class OwningModuleRef; -} // namespace mlir +} // namespace mlir //===----------------------------------------------------------------------===// // Import a model into one of ONNF's frontend models. @@ -41,7 +41,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model); * @return MLIR::module generated for the ONNX model. 
*/ void ImportFrontendModelFile(std::string model_fname, - mlir::MLIRContext& context, mlir::OwningModuleRef& module); + mlir::MLIRContext &context, + mlir::OwningModuleRef &module); /*! * TODO: Import models into other extension dialects that cover the diff --git a/src/compiler/analysis/CMakeLists.txt b/src/compiler/analysis/CMakeLists.txt new file mode 100644 index 0000000..01a1fcb --- /dev/null +++ b/src/compiler/analysis/CMakeLists.txt @@ -0,0 +1,5 @@ +add_library(DLCAnalysis STATIC + extract_integer_set.cpp) + +target_include_directories(DLCAnalysis PRIVATE ${DLC_SRC_ROOT}) +target_include_directories(DLCAnalysis PRIVATE ${DLC_BIN_ROOT}) \ No newline at end of file diff --git a/src/compiler/dialect/krnl/krnl_ops.cpp b/src/compiler/dialect/krnl/krnl_ops.cpp index 6454936..a436d7c 100644 --- a/src/compiler/dialect/krnl/krnl_ops.cpp +++ b/src/compiler/dialect/krnl/krnl_ops.cpp @@ -364,6 +364,14 @@ ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser, return success(); } +void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state, + SymbolRefAttr funcAttr, IntegerAttr numInputs, + IntegerAttr numOutputs) { + state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr); + state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs); + state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs); +} + #define GET_OP_CLASSES #include "src/compiler/krnl.cpp.inc" } // namespace mlir diff --git a/src/compiler/dialect/krnl/krnl_ops.hpp b/src/compiler/dialect/krnl/krnl_ops.hpp index 4b9fe4e..b51bda0 100644 --- a/src/compiler/dialect/krnl/krnl_ops.hpp +++ b/src/compiler/dialect/krnl/krnl_ops.hpp @@ -11,6 +11,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" diff --git a/src/compiler/dialect/krnl/krnl_ops.td b/src/compiler/dialect/krnl/krnl_ops.td index 
c410c70..bb894f7 100644 --- a/src/compiler/dialect/krnl/krnl_ops.td +++ b/src/compiler/dialect/krnl/krnl_ops.td @@ -8,35 +8,27 @@ include "mlir/IR/OpBase.td" - def Krnl_Dialect : Dialect { let name = "krnl"; let cppNamespace = ""; } // Require regions to have krnl.terminate terminator operation. -def ImplicitKrnlTerminator - : SingleBlockImplicitTerminator<"KrnlTerminatorOp">; +def ImplicitKrnlTerminator : SingleBlockImplicitTerminator<"KrnlTerminatorOp">; def KrnlDefineLoopsOp : Op { let summary = "define_loops operation"; let description = [{ - The "krnl.define_loops" operation is used to define input loops, those are the for loops appearing in the input program that we intend to optimize. - }]; let arguments = (ins); let results = (outs Variadic); - let skipDefaultBuilders = 1; - - let builders = [ - OpBuilder<"Builder *builder, OperationState &result," - "int64_t num_loops"> - ]; + let builders = [ OpBuilder<"Builder *builder, OperationState &result," + "int64_t num_loops"> ]; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; @@ -44,33 +36,27 @@ def KrnlDefineLoopsOp : Op { let extraClassDeclaration = [{ static StringRef getNumLoopsAttrName() { return "num_loops"; } - // Helper function to extract the number of loops being defined. - int64_t getNumLoops() { - auto num_loops = - getAttrOfType( - getNumLoopsAttrName()) - .getValue() - .getSExtValue(); - return num_loops; - } - }]; - - + // Helper function to extract the number of loops being defined. + int64_t getNumLoops() { + auto num_loops = getAttrOfType(getNumLoopsAttrName()) + .getValue() + .getSExtValue(); + return num_loops; + } +}]; } def KrnlOptimizeLoopsOp : Op { let summary = "optimize_loops operation"; let description = [{ - The "krnl.optimize_loops" operation is essentially a cosmetic operation - which exists to encapsulate a region where loops are being scheduled/optimized. 
+ which exists to encapsulate a region where loops are being scheduled / + optimized. - The optimized loops are returned at the end of the - region associated with the krnl.optimize_loops operation. - - For example: - TBD once we have actual schedule intrinsics. + The optimized loops are returned at the end of the region associated with + the krnl.optimize_loops operation. + For example : TBD once we have actual schedule intrinsics. }]; let arguments = (ins Variadic); @@ -79,10 +65,8 @@ def KrnlOptimizeLoopsOp : Op { let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<"Builder *builder, OperationState &result, " - "int timestamp_space_rank"> - ]; + let builders = [ OpBuilder<"Builder *builder, OperationState &result, " + "int timestamp_space_rank"> ]; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; @@ -91,7 +75,6 @@ def KrnlOptimizeLoopsOp : Op { def KrnlIterateOp : Op { let summary = "iterate operation"; let description = [{ - The "krnl.iterate" operation is conceptually equivalent to a nested for loops. For instance, say we have the following two @@ -103,25 +86,20 @@ def KrnlIterateOp : Op { Then, consider the following krnl.iterate operation: krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) { - // Some operations. + // Some operations. } It is equivalent to: - for (i0=0; i0<10; i0++) - for (i1=0; i1<10; i1++) + for (i0 = 0; i0 < 10; i0++) + for (i1 = 0; i1 < 10; i1++) // Some operations. 
}]; let arguments = (ins Variadic); - let regions = (region SizedRegion<1>:$bodyRegion); - let skipDefaultBuilders = 1; - - let builders = [ - OpBuilder<"Builder *builder, OperationState &result, " - "KrnlIterateOperandPack operandPack"> - ]; + let builders = [ OpBuilder<"Builder *builder, OperationState &result, " + "KrnlIterateOperandPack operandPack"> ]; let extraClassDeclaration = [{ // In krnl.iterate operation, operands are stored as such @@ -134,20 +112,19 @@ def KrnlIterateOp : Op { int64_t getNumOptimizedLoops() { auto num_optimized_loops = - getAttrOfType( - getNumOptimizedLoopsAttrName()) - .getValue() - .getSExtValue(); + getAttrOfType(getNumOptimizedLoopsAttrName()) + .getValue() + .getSExtValue(); return num_optimized_loops; } // Get name of the attribute for storing bound represented using affine maps. - static StringRef getBoundsAttrName() { return "bounds"; } + static StringRef getBoundsAttrName() { return "bounds"; } }]; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; + let verifier = [{ return ::verify(*this); }]; } def KrnlReturnLoopsOp : Op { @@ -182,11 +159,30 @@ def KrnlTerminatorOp : Op { let verifier = ?; } +def KrnlEntryPointOp : Op { + let summary = "Indicate ONNX entry point"; + let description = [{The "krnl.entry_point" function indicates the main entry + point of ONNX model.}]; + let builders = [ OpBuilder<"Builder *builder, OperationState &result, " + "SymbolRefAttr funcAttr, IntegerAttr numInputs, " + "IntegerAttr numOutputs"> ]; + + let extraClassDeclaration = [{ + static StringRef getEntryPointFuncAttrName() { return "func"; } + static StringRef getNumInputsAttrName() { return "numInputs"; } + static StringRef getNumOutputsAttrName() { return "numOutputs"; } + }]; + + // No custom parsing/printing form. 
+ let parser = ?; + let printer = ?; +} + def KrnlMemcpyOp : Op { let summary = "Krnl memcpy operation"; let description = [{ - In the KRNL dialect the reshape op doesn't generate a new memory entry and - treats a reshape like a cast. + In the KRNL dialect the reshape op + doesn't generate a new memory entry and treats a reshape like a cast. }]; let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size); diff --git a/src/compiler/dialect/onnx/onnx.td b/src/compiler/dialect/onnx/onnx.td index 4244d58..87901ce 100644 --- a/src/compiler/dialect/onnx/onnx.td +++ b/src/compiler/dialect/onnx/onnx.td @@ -58,6 +58,27 @@ class ONNX_Op traits = []> : include "dialect/onnx/onnxop.inc" +// Indicate entry point functions of ONNX graph. +def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { + let summary = "Indicate ONNX entry point"; + let description = [{ + The "onnx.EntryPoint" function indicates the main entry point of ONNX model. + }]; + + let builders = [OpBuilder<[{Builder *builder, OperationState &state, + FuncOp function, int numInputs, int numOutputs}]>]; + + let extraClassDeclaration = [{ + static ONNXEntryPointOp create(Location location, FuncOp& func, + int numInputs, int numOutputs); + + static StringRef getEntryPointFuncAttrName() { return "func"; } + static StringRef getNumInputsAttrName() { return "numInputs"; } + static StringRef getNumOutputsAttrName() { return "numOutputs"; } + }]; +} + + def ONNXFullGemmOp: ONNX_Op<"FullGemm", [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX general matrix multiply operation"; diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp index da50632..5b60279 100644 --- a/src/compiler/dialect/onnx/onnx_ops.cpp +++ b/src/compiler/dialect/onnx/onnx_ops.cpp @@ -7,7 +7,6 @@ // This file defines ONNX operations in the MLIR operation set. 
// //===----------------------------------------------------------------------===// - #include "mlir/Dialect/Traits.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" @@ -38,6 +37,28 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx) >(); } +void ONNXEntryPointOp::build(mlir::Builder *builder, + mlir::OperationState &state, mlir::FuncOp function, + int numInputs, int numOutputs) { + state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(), + builder->getSymbolRefAttr(function)); + state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(), + builder->getI32IntegerAttr(numInputs)); + state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(), + builder->getI32IntegerAttr(numOutputs)); +} + +ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location, + mlir::FuncOp &func, int numInputs, + int numOutputs) { + mlir::OperationState state(location, "onnx.EntryPoint"); + Builder builder(location->getContext()); + mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs); + Operation *op = mlir::Operation::create(state); + auto onnxEntryOp = llvm::cast(op); + return onnxEntryOp; +} + //===----------------------------------------------------------------------===// // ONNX Operations //===----------------------------------------------------------------------===// diff --git a/src/compiler/dialect/onnx/onnx_ops.hpp b/src/compiler/dialect/onnx/onnx_ops.hpp index deab78a..9903b62 100644 --- a/src/compiler/dialect/onnx/onnx_ops.hpp +++ b/src/compiler/dialect/onnx/onnx_ops.hpp @@ -10,6 +10,7 @@ #pragma once +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp index c1ca650..b94a269 100644 --- a/src/compiler/pass/lower_frontend_to_krnl.cpp +++ b/src/compiler/pass/lower_frontend_to_krnl.cpp @@ -8,7 +8,6 @@ // Krnl IR and standard operations. 
// //===----------------------------------------------------------------------===// - #include #include "mlir/Dialect/AffineOps/AffineOps.h" @@ -884,6 +883,27 @@ struct ONNXReshapeOpLowering : public ConversionPattern { } }; +//===----------------------------------------------------------------------===// +// EntryPoint Op lowering to Krnl Entry Point. +//===----------------------------------------------------------------------===// + +class ONNXEntryPointLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ONNXEntryPointOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, + op.getAttrOfType( + ONNXEntryPointOp::getEntryPointFuncAttrName()), + op.getAttrOfType(ONNXEntryPointOp::getNumInputsAttrName()), + op.getAttrOfType( + ONNXEntryPointOp::getNumOutputsAttrName())); + return matchSuccess(); + } +}; + //===----------------------------------------------------------------------===// // Conversion from Tensor type to the Standard dialect MemRef type. //===----------------------------------------------------------------------===// @@ -985,7 +1005,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, - ONNXReshapeOpLowering>(&getContext()); + ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/src/compiler/transform/lower_krnl.cpp b/src/compiler/transform/lower_krnl.cpp index 36f8f56..6e53038 100644 --- a/src/compiler/transform/lower_krnl.cpp +++ b/src/compiler/transform/lower_krnl.cpp @@ -143,14 +143,16 @@ void KrnlToAffineLoweringPass::runOnFunction() { // We expect IR to be free of Krnl Dialect Ops. 
target.addIllegalDialect(); target.addLegalOp(); + target.addLegalOp(); OwningRewritePatternList patterns; patterns.insert( &getContext()); - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed(applyPartialConversion(getFunction(), target, patterns))) { signalPassFailure(); + } } } // namespace diff --git a/src/compiler/transform/lower_to_llvm.cpp b/src/compiler/transform/lower_to_llvm.cpp index 952635b..fec2e03 100644 --- a/src/compiler/transform/lower_to_llvm.cpp +++ b/src/compiler/transform/lower_to_llvm.cpp @@ -4,7 +4,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/Sequence.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" @@ -15,6 +14,7 @@ #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" #include "src/compiler/dialect/krnl/krnl_ops.hpp" #include "src/compiler/pass/passes.hpp" @@ -23,20 +23,39 @@ using namespace mlir; namespace { +static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName, + ModuleOp module, + mlir::LLVM::LLVMType funcType, + PatternRewriter &rewriter) { + auto *context = module.getContext(); + if (module.lookupSymbol(funcName)) { + auto symbolRef = SymbolRefAttr::get(funcName, context); + assert(symbolRef.getType() == funcType && "wrong symbol type"); + return symbolRef; + } + + // Insert the function into the body of the parent module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), funcName, funcType); + return SymbolRefAttr::get(funcName, context); +} + //===----------------------------------------------------------------------===// -// KRNL to LLVM: patterns which need a direct lowering to LLVM. 
+// KRNL to LLVM: KrnlMemcpyOpLowering //===----------------------------------------------------------------------===// class KrnlMemcpyOpLowering : public ConversionPattern { - public: - explicit KrnlMemcpyOpLowering(MLIRContext* context) +public: + explicit KrnlMemcpyOpLowering(MLIRContext *context) : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation* op, ArrayRef operands, - ConversionPatternRewriter& rewriter) const override { - auto* context = op->getContext(); + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto *context = op->getContext(); auto loc = op->getLoc(); - auto* llvmDialect = + auto *llvmDialect = op->getContext()->getRegisteredDialect(); assert(llvmDialect && "expected llvm dialect to be registered"); @@ -47,39 +66,40 @@ class KrnlMemcpyOpLowering : public ConversionPattern { // First operand. Type dstType = operands[0]->getType().cast().getStructElementType(1); - Value* alignedDstMemory = rewriter.create( + Value *alignedDstMemory = rewriter.create( loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); - Value* alignedInt8PtrDstMemory = rewriter.create( + Value *alignedInt8PtrDstMemory = rewriter.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); // Second operand. Type srcType = operands[1]->getType().cast().getStructElementType(1); - Value* alignedSrcMemory = rewriter.create( + Value *alignedSrcMemory = rewriter.create( loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); - Value* alignedInt8PtrSrcMemory = rewriter.create( + Value *alignedInt8PtrSrcMemory = rewriter.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); // Size. 
- Value* int64Size = rewriter.create( + Value *int64Size = rewriter.create( loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); // Memcpy call - rewriter.create(loc, memcpyRef, - LLVM::LLVMType::getVoidTy(llvmDialect), - ArrayRef( + rewriter.create( + loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), + ArrayRef( {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size})); rewriter.eraseOp(op); return matchSuccess(); } - private: +private: /// Return a symbol reference to the memcpy function, inserting it into the /// module if necessary. - static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter, - ModuleOp module, LLVM::LLVMDialect* llvmDialect) { - auto* context = module.getContext(); + static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { + auto *context = module.getContext(); if (module.lookupSymbol("llvm.memcpy.p0i8.p0i8.i64")) return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); // Create a function declaration for memcpy, the signature is: @@ -87,19 +107,355 @@ class KrnlMemcpyOpLowering : public ConversionPattern { auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect); auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); - auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy, + auto llvmFnType = LLVM::LLVMType::getFunctionTy( + llvmVoidTy, ArrayRef({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}), false); // Insert the memcpy function into the body of the parent module. 
PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create( - module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); + rewriter.create(module.getLoc(), + "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); } }; -} // end namespace + +//===----------------------------------------------------------------------===// +// KRNL to LLVM: KrnlEntryPointOp +//===----------------------------------------------------------------------===// + +class KrnlEntryPointOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + enum class API { + CREATE_ORDERED_DYN_MEM_REF_DICT, + CREATE_DYN_MEM_REF, + GET_DYN_MEM_REF, + SET_DYN_MEM_REF, + GET_DATA, + SET_DATA, + GET_SIZES, + GET_STRIDES, + }; + + struct ApiSpec { + API id; + std::string name; + FlatSymbolRefAttr symbolRef; + LLVM::LLVMType outputTy; + SmallVector inputTys; + + ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy, + ArrayRef inputTys) + : id(id), name(name), outputTy(outputTy), + inputTys(inputTys.begin(), inputTys.end()) {} + + LLVM::LLVMType funcTy() { + return LLVM::LLVMType::getFunctionTy(outputTy, inputTys, + /*isVarArg=*/false); + } + }; + + PatternMatchResult matchAndRewrite(KrnlEntryPointOp op, + PatternRewriter &rewriter) const override { + + auto *llvmDialect = + op.getContext()->getRegisteredDialect(); + assert(llvmDialect && "expected llvm dialect to be registered"); + auto module = op.getParentOfType(); + auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect); + auto loc = op.getLoc(); + auto numOutputs = + op.getAttrOfType(KrnlEntryPointOp::getNumOutputsAttrName()) + .getInt(); + + using LLVMType = LLVM::LLVMType; + auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect); + auto int32Ty = LLVMType::getInt32Ty(llvmDialect); + + // Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic + // signature. 
The signature is dynamic because it remains the same no matter + // what the model input/output schema look like. Such dynamic signature + // takes a opaque ptr as input, representing a ptr to a data structure + // containing a set of dynamic memrefs wrapped in a vector; similarly the + // output is also a opaque ptr to a data structure with output memrefs + // wrapped within it. + auto staticEntryPointFuncName = + op.getAttrOfType( + KrnlEntryPointOp::getEntryPointFuncAttrName()) + .getLeafReference(); + auto dynEntryPointName = "_dyn_entry_point_" + staticEntryPointFuncName; + assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr && + "dynamic entry point name is not unique"); + rewriter.eraseOp(op); + auto dynEntryPointFuncTy = + LLVMType::getFunctionTy(opaquePtrTy, {opaquePtrTy}, false); + auto dynamicEntryPointFunc = rewriter.create( + loc, dynEntryPointName.str(), dynEntryPointFuncTy); + auto &entryPointEntryBlock = + createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc); + rewriter.setInsertionPointToStart(&entryPointEntryBlock); + + // Based on the static entry point type signature, unpack dynamic memory + // refs to corresponding static memory refs. + auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName); + assert(staticEntryPointFunc && + isa(staticEntryPointFunc) && + "entry point func must exist and be an llvm func op"); + auto staticEntryPointTy = dyn_cast(staticEntryPointFunc) + .getType() + .dyn_cast(); + + // Retrieve dynamic mem refs from wrapped input, and convert every one of + // them to static mem refs. + SmallVector staticInputs; + auto wrappedInput = entryPointEntryBlock.getArgument(0); + for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) { + // Call API function to retrieve the i-th dynamic memref. 
+ auto idxVal = rewriter.create( + loc, int32Ty, rewriter.getI32IntegerAttr(i)); + auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF, + {wrappedInput, idxVal}); + + // Create a (static) memref type corresponding to the i-th memref input to + // the inference function on stack, and load it to memRef. + auto memRefPtrTy = staticEntryPointTy.getFunctionParamType(i); + auto memRefTy = memRefPtrTy.getPointerElementTy(); + auto one = rewriter.create( + loc, int32Ty, rewriter.getI32IntegerAttr(1)); + Value *ptrToMemRef = + rewriter.create(loc, memRefPtrTy, one, + /*alignment=*/0); + + // Fill in the memref underlying ptrToMemRef with information extracted + // from dynMemRef. + fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc, + apiRegistry, llvmDialect); + + // ptrToMemRef will be an input to main computation graph function. + staticInputs.emplace_back(ptrToMemRef); + } + + // If more than one output exists, the struct becomes a nested struct, + // the unpacking logic can be more involved, so no support for now. + assert(numOutputs == 1 && "only support 1 output tensor now."); + + // Call static entry point with the memref ptrs created, and get output. + auto outputMemRefs = rewriter.create( + loc, staticEntryPointTy.getFunctionResultType(), + rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs); + + // Create wrapped output. + auto wrappedOutput = callApi(rewriter, loc, apiRegistry, + API::CREATE_ORDERED_DYN_MEM_REF_DICT, {}); + + // Get the first memref returned, convert to a dynamic memref and store + // it in the wrapped Output. 
+ auto outMemRef = outputMemRefs.getResult(0); + auto outMemRefTy = outMemRef->getType().dyn_cast(); + auto outMemRefRank = + outMemRefTy.getStructElementType(3).getArrayNumElements(); + auto outMemRefRankVal = rewriter.create( + loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); + auto outDynMemRef = callApi(rewriter, loc, apiRegistry, + API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); + fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc, + apiRegistry, llvmDialect); + auto zero = rewriter.create( + loc, int32Ty, rewriter.getI32IntegerAttr(0)); + callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, + {wrappedOutput, zero, outDynMemRef}); + + // Return wrapped output. + rewriter.create(loc, + SmallVector({wrappedOutput})); + return matchSuccess(); + } + +private: + using ApiRegistry = std::map; + + ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter, + LLVM::LLVMDialect *llvmDialect) const { + using LLVMType = LLVM::LLVMType; + auto voidTy = LLVMType::getVoidTy(llvmDialect); + auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect); + auto int32Ty = LLVMType::getInt32Ty(llvmDialect); + auto int64Ty = LLVMType::getInt64Ty(llvmDialect); + auto int64PtrTy = int64Ty.getPointerTo(); + + // Declare API type as an enum value, its string name and an LLVM Type + // specifying its signature. 
+ // clang-format off + std::vector apiSpecs = { + ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}), + ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}), + ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}), + ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}), + ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}), + ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}), + ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}), + ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy}) + }; + // clang-format on + + // Declare APIs in the current module and build an API registry mapping api + // identities to a symbol reference to the API function. + ApiRegistry registry; + for (auto &apiSpec : apiSpecs) { + apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module, + apiSpec.funcTy(), rewriter); + registry.emplace(apiSpec.id, apiSpec); + } + + return registry; + } + + // Call a registered API, return the return SSA values if only one result is + // returned, otherwise return nullptr. + Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, + API apiId, ArrayRef params) const { + auto returnVals = rewriter.create( + loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef, + ArrayRef(params)); + if (returnVals.getNumResults() == 1) + return returnVals.getResult(0); + return nullptr; + } + + // Helper function to insert an entry block to LLVM function. + // (TODO): upstream this to MLIR. 
+ Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType, + LLVM::LLVMFuncOp &dynamicEntryPointFunc) const { + // Add entry block: + auto *entryPointEntryBlock = new Block(); + dynamicEntryPointFunc.push_back(entryPointEntryBlock); + llvm::SmallVector argTypes; + for (size_t i = 0; i < dynEntryPointFuncType.getFunctionNumParams(); i++) + argTypes.emplace_back(dynEntryPointFuncType.getFunctionParamType(i)); + entryPointEntryBlock->addArguments(argTypes); + return *entryPointEntryBlock; + } + + void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef, + PatternRewriter &rewriter, + const Location &loc, + const std::map &apiRegistry, + LLVM::LLVMDialect *llvmDialect) const { + auto memRefPtrTy = ptrToMemRef.getType().dyn_cast(); + auto memRefTy = memRefPtrTy.getPointerElementTy(); + auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + + Value *memRef = + rewriter.create(loc, memRefPtrTy, &ptrToMemRef); + + // Set dataPtr and alignedDataPtr; + auto dataPtr = + callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef}); + dataPtr = rewriter.create( + loc, memRefTy.getStructElementType(0), dataPtr); + memRef = rewriter.create( + loc, memRefTy, memRef, dataPtr, + rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)})); + memRef = rewriter.create( + loc, memRefTy, memRef, dataPtr, + rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)})); + + // Use zero offset now. + auto zero = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(0)); + memRef = rewriter.create( + loc, memRefTy, memRef, zero, + rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)})); + + // Get rank, sizes array ptr and strides array ptr. 
+    auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
+    auto sizesArrayPtr =
+        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef});
+    auto stridesArrayPtr =
+        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef});
+
+    for (decltype(rank) i = 0; i < rank; i++) {
+      auto dimIdx = rewriter.create(
+          loc, int64Ty, rewriter.getI64IntegerAttr(i));
+
+      // Insert size of the dimension.
+      auto dimSizePtr = rewriter.create(
+          loc, int64Ty.getPointerTo(), sizesArrayPtr,
+          ArrayRef({dimIdx}));
+      auto dimSize = rewriter.create(loc, int64Ty.getPointerTo(),
+                                     dimSizePtr);
+      memRef = rewriter.create(
+          loc, memRefTy, memRef, dimSize,
+          rewriter.getArrayAttr(
+              {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
+
+      // Insert stride of the dimension.
+      auto dimStridePtr = rewriter.create(
+          loc, int64Ty.getPointerTo(), stridesArrayPtr,
+          ArrayRef({dimIdx}));
+      auto dimStride = rewriter.create(
+          loc, int64Ty.getPointerTo(), dimStridePtr);
+      memRef = rewriter.create(
+          loc, memRefTy, memRef, dimStride,
+          rewriter.getArrayAttr(
+              {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
+    }
+
+    rewriter.create(loc, memRef, &ptrToMemRef);
+  }
+
+  void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
+                               PatternRewriter &rewriter, const Location &loc,
+                               const std::map &apiRegistry,
+                               LLVM::LLVMDialect *llvmDialect) const {
+    auto outMemRefTy = outMemRef.getType().dyn_cast();
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
+
+    // Extract the data pointer, and record it in dynamic mem ref created.
+ Value *outMemRefDataPtr = rewriter.create( + loc, outMemRefTy.getStructElementType(0), &outMemRef, + rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); + outMemRefDataPtr = rewriter.create( + loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); + callApi(rewriter, loc, apiRegistry, API::SET_DATA, + {&outDynMemRef, outMemRefDataPtr}); + + auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements(); + auto sizesArrayPtr = + callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef}); + auto stridesArrayPtr = + callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef}); + + for (decltype(rank) i = 0; i < rank; i++) { + auto dimIdx = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(i)); + + // Transfer size of dimension from memref to dynamic memref. + auto dimSize = rewriter.create( + loc, int64Ty, &outMemRef, + rewriter.getArrayAttr( + {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); + auto dimSizePtr = rewriter.create( + loc, int64Ty.getPointerTo(), sizesArrayPtr, + ArrayRef({dimIdx})); + rewriter.create(loc, dimSize, dimSizePtr); + + // Transfer stride of dimension from memref to dynamic memref. + auto dimStride = rewriter.create( + loc, int64Ty, &outMemRef, + rewriter.getArrayAttr( + {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); + auto dimStridePtr = rewriter.create( + loc, int64Ty.getPointerTo(), stridesArrayPtr, + ArrayRef({dimIdx})); + rewriter.create(loc, dimStride, dimStridePtr); + } + } +}; +} // end namespace //===----------------------------------------------------------------------===// // KRNL + Stadard + Affine dialects lowering to LLVM. @@ -109,7 +465,7 @@ namespace { struct KrnlToLLVMLoweringPass : public ModulePass { void runOnModule() final; }; -} // end anonymous namespace +} // end anonymous namespace void KrnlToLLVMLoweringPass::runOnModule() { // Define the target for this lowering i.e. the LLVM dialect. 
@@ -128,12 +484,13 @@ void KrnlToLLVMLoweringPass::runOnModule() { populateStdToLLVMConversionPatterns(typeConverter, patterns); // Lower from the `krnl` dialect i.e. the Reshape operation. - patterns.insert(&getContext()); + patterns.insert( + &getContext()); // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. - auto module = getModule(); - if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + if (failed( + applyFullConversion(getModule(), target, patterns, &typeConverter))) signalPassFailure(); } @@ -142,5 +499,5 @@ std::unique_ptr mlir::createKrnlLowerToLLVMPass() { return std::make_unique(); } -static PassRegistration pass( - "lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM."); +static PassRegistration + pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM."); diff --git a/src/main.cpp b/src/main.cpp index e3896ba..83d7a56 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -7,10 +7,7 @@ //===----------------------------------------------------------------------===// #include -#include #include -#include -#include #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/Support/CommandLine.h" @@ -24,9 +21,7 @@ #include "src/compiler/dialect/onnx/onnx_ops.hpp" #include "src/compiler/pass/passes.hpp" -#include "mlir/Analysis/Verifier.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/MLIRContext.h" diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt new file mode 100644 index 0000000..8522310 --- /dev/null +++ b/src/runtime/CMakeLists.txt @@ -0,0 +1,17 @@ +add_library(cruntime + dyn_memref.cpp + dyn_memref.h + data_type.h) +target_include_directories(cruntime + PRIVATE ${DLC_SRC_ROOT} ${DLC_BIN_ROOT} + 
${DLC_SRC_ROOT})
+
+pybind11_add_module(pyruntime
+        dyn_memref.cpp
+        dyn_memref.h
+        runtime.cpp
+        runtime.hpp)
+target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
+target_include_directories(pyruntime
+        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
+                ${ONNF_SRC_ROOT})
diff --git a/src/runtime/data_type.h b/src/runtime/data_type.h
new file mode 100644
index 0000000..9631318
--- /dev/null
+++ b/src/runtime/data_type.h
@@ -0,0 +1,30 @@
+enum DYN_MEMREF_DATA_TYPE {
+  UNDEFINED = 0,
+  // Basic types.
+  FLOAT = 1,   // float
+  UINT8 = 2,   // uint8_t
+  INT8 = 3,    // int8_t
+  UINT16 = 4,  // uint16_t
+  INT16 = 5,   // int16_t
+  INT32 = 6,   // int32_t
+  INT64 = 7,   // int64_t
+  STRING = 8,  // string
+  BOOL = 9,    // bool
+
+  // IEEE754 half-precision floating-point format (16 bits wide).
+  // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
+  FLOAT16 = 10,
+
+  DOUBLE = 11,
+  UINT32 = 12,
+  UINT64 = 13,
+  COMPLEX64 = 14,   // complex with float32 real and imaginary components
+  COMPLEX128 = 15,  // complex with float64 real and imaginary components
+
+  // Non-IEEE floating-point format based on IEEE754 single-precision
+  // floating-point number truncated to 16 bits.
+  // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+  BFLOAT16 = 16,
+
+  // Future extensions go here.
+};
\ No newline at end of file
diff --git a/src/runtime/dyn_memref.cpp b/src/runtime/dyn_memref.cpp
new file mode 100644
index 0000000..a20001c
--- /dev/null
+++ b/src/runtime/dyn_memref.cpp
@@ -0,0 +1,74 @@
+#include
+#include
+#include
+#include
+
+#include "dyn_memref.h"
+
+DynMemRef::DynMemRef(int _rank) {
+  rank = _rank;
+  sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
+  strides = (int64_t *)malloc(rank * sizeof(int64_t));
+}
+
+// An ordered dynamic MemRef dictionary.
+// The goal is to support accessing dynamic memory ref by name and by index.
+// Currently, only accessing by index is supported.
+struct OrderedDynMemRefDict {
+  std::map tensorDict;
+  std::vector orderedNames;
+};
+
+int numDynMemRefs(OrderedDynMemRefDict *dict) {
+  return dict->orderedNames.size();
+}
+
+OrderedDynMemRefDict *createOrderedDynMemRefDict() {
+  return new OrderedDynMemRefDict();
+}
+
+DynMemRef *createDynMemRef(int rank) { return new DynMemRef(rank); }
+
+DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
+  return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
+}
+
+void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
+                  DynMemRef *tensor) {
+  if (tensorDict->orderedNames.size() <= idx)
+    tensorDict->orderedNames.resize(idx + 1);
+
+  // The dynamic memref is essentially anonymous, since we are storing it by
+  // indexed position.
+  // TODO: can use random string as names to reduce chance of collision.
+  auto unique_name = std::to_string(idx);
+  assert(tensorDict->tensorDict.count(unique_name) == 0 &&
+         "duplicate dynamic mem ref name");
+
+  tensorDict->orderedNames[idx] = unique_name;
+  tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
+}
+
+void *getData(DynMemRef *dynMemRef) { return dynMemRef->data; }
+
+void setData(DynMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
+
+void *getAlignedData(DynMemRef *dynMemRef) { return dynMemRef->alignedData; }
+
+void setAlignedData(DynMemRef *dynMemRef, void *dataPtr) {
+  dynMemRef->alignedData = dataPtr;
+}
+
+INDEX_TYPE *getSizes(DynMemRef *dynMemRef) { return dynMemRef->sizes; }
+
+void setSizes(DynMemRef *dynMemRef, INDEX_TYPE *sizes) {
+  for (int i = 0; i < dynMemRef->rank; i++)
+    dynMemRef->sizes[i] = sizes[i];
+}
+
+int64_t *getStrides(DynMemRef *dynMemRef) { return dynMemRef->strides; }
+
+void setStrides(DynMemRef *dynMemRef, int64_t *strides) {
+  for (int i = 0; i < dynMemRef->rank; i++)
+    dynMemRef->strides[i] = strides[i];
+}
diff --git a/src/runtime/dyn_memref.h b/src/runtime/dyn_memref.h
new file mode 100644
index 0000000..e46f396
--- /dev/null
+++
b/src/runtime/dyn_memref.h @@ -0,0 +1,61 @@ +#pragma once + +#include + +typedef int64_t INDEX_TYPE; + +// This is a dynamic version of memref. +// The same struct can be used to represent memrefs of +// all ranks and type combinations. +struct DynMemRef { + void *data; + void *alignedData; + INDEX_TYPE offset; + + unsigned int rank; + INDEX_TYPE *sizes; + int64_t *strides; + + DynMemRef(int _rank); +}; + +extern "C" { + +// Ordered DynMemRef Dictionary is a data structure for wrapping the input +// dynmemrefs so that they can be addressed both by index and by name. +struct OrderedDynMemRefDict; + +// Get number of dynamic memrefs in OrderedDynMemRefDict dict. +int numDynMemRefs(OrderedDynMemRefDict *dict); + +// Create an ordered dynmemref dictionary. +OrderedDynMemRefDict *createOrderedDynMemRefDict(); + +// Create a dynmemref with a certain rank. +DynMemRef *createDynMemRef(int rank); + +// Get the i-th dynmemref from orderedDict. +DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i); + +// Set the i-th dynmemref in orderedDict to be dynMemRef. +void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, + DynMemRef *dynMemRef); + +// Get data pointer from dynMemRef. +void *getData(DynMemRef *dynMemRef); + +// Set data pointer for dynMemRef. +void setData(DynMemRef *dynMemRef, void *data); + +// Get algined data pointer from dynMemRef. +void *getAlignedData(DynMemRef *); + +// Set aligned data pointer for dynMemRef. +void setAlignedData(DynMemRef *, void *); + +// Get ptr to sizes array. +INDEX_TYPE *getSizes(DynMemRef *); + +// Get ptr to strides array. 
+int64_t *getStrides(DynMemRef *);
+}
diff --git a/src/runtime/runtime.cpp b/src/runtime/runtime.cpp
new file mode 100644
index 0000000..5a66a3a
--- /dev/null
+++ b/src/runtime/runtime.cpp
@@ -0,0 +1,52 @@
+#include "runtime.hpp"
+
+ExecutionSession::ExecutionSession(std::string sharedLibPath,
+                                   std::string entryPointName) {
+  _sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
+  _entryPointFunc =
+      (entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
+}
+
+std::vector
+ExecutionSession::run(std::vector inputsPyArray) {
+  assert(_entryPointFunc && "entry point not loaded");
+  auto *wrappedInput = createOrderedDynMemRefDict();
+  int inputIdx = 0;
+  for (auto inputPyArray : inputsPyArray) {
+    auto *inputDynMemRef = createDynMemRef(inputPyArray.ndim());
+    assert(inputPyArray.flags() & py::array::c_style &&
+           "expect contiguous python array");
+
+    if (inputPyArray.writeable()) {
+      inputDynMemRef->data = inputPyArray.mutable_data();
+      inputDynMemRef->alignedData = inputPyArray.mutable_data();
+    } else {
+      // If data is not writable, copy them to a writable buffer.
+ auto *copiedData = (float *)malloc(inputPyArray.nbytes()); + memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes()); + inputDynMemRef->data = copiedData; + inputDynMemRef->alignedData = copiedData; + } + + for (int i = 0; i < inputPyArray.ndim(); i++) { + inputDynMemRef->sizes[i] = inputPyArray.shape(i); + inputDynMemRef->strides[i] = inputPyArray.strides(i); + } + + setDynMemRef(wrappedInput, inputIdx++, inputDynMemRef); + } + + std::vector outputPyArrays; + auto *wrappedOutput = _entryPointFunc(wrappedInput); + for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) { + auto *dynMemRef = getDynMemRef(wrappedOutput, i); + auto shape = std::vector(dynMemRef->sizes, + dynMemRef->sizes + dynMemRef->rank); + outputPyArrays.emplace_back( + py::array(py::dtype("float32"), shape, dynMemRef->data)); + } + + return outputPyArrays; +} + +ExecutionSession::~ExecutionSession() { dlclose(_sharedLibraryHandle); } diff --git a/src/runtime/runtime.hpp b/src/runtime/runtime.hpp new file mode 100644 index 0000000..cb3d9bc --- /dev/null +++ b/src/runtime/runtime.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "src/runtime/dyn_memref.h" + +namespace py = pybind11; + +typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *); + +class ExecutionSession { +public: + ExecutionSession(std::string sharedLibPath, std::string entryPointName); + + std::vector run(std::vector inputsPyArray); + + ~ExecutionSession(); + +private: + // Handler to the shared library file being loaded. + void *_sharedLibraryHandle = nullptr; + + // Entry point function. 
+ entryPointFuncType _entryPointFunc = nullptr; +}; + +PYBIND11_MODULE(pyruntime, m) { + py::class_(m, "ExecutionSession") + .def(py::init()) + .def("run", &ExecutionSession::run); +} \ No newline at end of file diff --git a/test/onnx_backend_test.py b/test/onnx_backend_test.py new file mode 100644 index 0000000..5300279 --- /dev/null +++ b/test/onnx_backend_test.py @@ -0,0 +1,162 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import itertools +import os +import unittest +import onnx.backend.base +import onnx.backend.test + +from onnx.backend.base import Device, DeviceType +import onnx.shape_inference +import onnx.version_converter +from typing import Optional, Text, Any, Tuple, Sequence +from onnx import NodeProto, ModelProto, TensorProto +import numpy # type: ignore +import subprocess +from pyruntime import ExecutionSession + +CXX = os.getenv('CXX') +FE = os.getenv('FE') +LLC = os.getenv('LLC') +RT_DIR = os.getenv('RT_DIR') +assert CXX and FE and LLC and RT_DIR, "tools path not set" + +class DummyBackend(onnx.backend.base.Backend): + @classmethod + def prepare( + cls, + model, + device='CPU', + **kwargs + ): + super(DummyBackend, cls).prepare(model, device, **kwargs) + # Save model to disk as temp_model.onnx. + onnx.save(model, "temp_model.onnx") + # Call frontend to process temp_model.onnx, bit code will be generated. + subprocess.run([FE, "temp_model.onnx"], stdout=subprocess.PIPE) + # Call llc to generate object file from bitcode. + subprocess.run([LLC, "-filetype=obj", "model.bc"], + stdout=subprocess.PIPE) + # Generate shared library from object file, linking with c runtime. 
+ subprocess.run([ + CXX, "-shared", "model.o", "-o", "model.so", "-L" + RT_DIR, + "-lcruntime" + ], + stdout=subprocess.PIPE) + return ExecutionSession("./model.so", "_dyn_entry_point_main_graph") + + @classmethod + def supports_device(cls, device): + d = Device(device) + if d.type == DeviceType.CPU: + return True + return False + + +backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__) + +# Test directories: +# https://github.com/onnx/onnx/tree/master/onnx/backend/test/data/node + +test_to_enable = [ + # Add Op: + "test_add_cpu", + "test_add_bcast_cpu", + + # And Op: + + # Sub Op: + "test_sub_cpu", + "test_sub_bcast_cpu", + "test_sub_example_cpu", + + # Cosh Op: + "test_cosh_cpu", + "test_cosh_example_cpu", + + # Div Op: + "test_div_cpu", + "test_div_bcast_cpu", + "test_div_example_cpu", + + # Elu Op: + "test_elu_cpu", + "test_elu_default_cpu", + "test_elu_example_cpu", + + # Exp Op: + "test_exp_cpu", + "test_exp_example_cpu", + + # Hard Sigmoid Op: + "test_hardsigmoid_cpu", + "test_hardsigmoid_default_cpu", + "test_hardsigmoid_example_cpu", + + # Leaky Relu Op: + "test_leakyrelu_cpu", + "test_leakyrelu_default_cpu", + "test_leakyrelu_example_cpu", + + # Max Op: + # "test_max_example_cpu", <- error + "test_max_one_input_cpu", + # "test_max_two_inputs_cpu", <- error + + # Min Op: + # "test_min_example_cpu", <- error + "test_min_one_input_cpu", + # "test_min_two_inputs_cpu", <- error + + # Mul Op: + "test_mul_cpu", + "test_mul_bcast_cpu", + "test_mul_example_cpu", + + # Relu Op: + "test_relu_cpu", + + # Selu Op: + "test_selu_cpu", + "test_selu_default_cpu", + "test_selu_example_cpu", + + # Sigmoid Op: + "test_sigmoid_cpu", + "test_sigmoid_example_cpu", + + # Sum Op: + #"test_sum_example_cpu", <- error + "test_sum_one_input_cpu", + #"test_sum_two_inputs_cpu", <- error + + # Reciprocal Op: + #"test_reciprocal_cpu", <- error on shape inference. + #"test_reciprocal_example_cpu", <- error on shape inference. +] + +# Extract name of all test cases. 
+import inspect +all_tests = inspect.getmembers( + backend_test.test_cases["OnnxBackendNodeModelTest"]) +all_test_names = list(map(lambda x: x[0], all_tests)) +for test_name in test_to_enable: + assert test_name in all_test_names, "test name {} not found".format(test_name) + backend_test.include(r"^{}$".format(test_name)) + + +def tearDownModule(): + print() + print("*" * 40) + print("A total of {} tests should have run".format(len(test_to_enable))) + print("*" * 40) + + +# import all test cases at global scope to make them visible to python.unittest +globals().update(backend_test.test_cases) + +if __name__ == '__main__': + unittest.main()