Enable ONNX Backend Test (#1)

* wip, commit before merging with upstream

* organize API, return wrapped output

* enable onnx backend test

* undo unintentional commit

* fix krnl ops tablegen

* format krnl ops

* reorder fillDynMemRefWithMemRef to be after fillPtrToMemRefWithDynMemRef, better comments

* more onnx backend tests

* ensure that test names refer to existing tests

* improve code readability by shortening type names

* nit

* restore unintentional changes

* more nits

* fix ; -> :

* split runtime implementation into header and body file, add support for data types

* comment on the onnx backend test

* make the comments read better

* do not dump when lowering
This commit is contained in:
Tian Jin 2019-12-22 00:25:02 -05:00
parent 5573cb39fe
commit 685bf23b40
22 changed files with 972 additions and 106 deletions

View File

@ -26,6 +26,7 @@ add_subdirectory(third_party/pybind11)
set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD 14)
add_subdirectory(src/builder) add_subdirectory(src/builder)
add_subdirectory(src/compiler) add_subdirectory(src/compiler)
add_subdirectory(src/runtime)
add_subdirectory(src) add_subdirectory(src)
add_subdirectory(test) add_subdirectory(test)

View File

@ -34,7 +34,7 @@ set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include)
set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include) set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include)
set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin) set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin)
set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/src/compiler/tool/onnf_opt) set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin)
set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir) set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir) set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)

View File

@ -614,7 +614,6 @@ private:
onnx::NodeProto node, int nIn, int nOut, onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) { attrs) {
// Conv has attribute dilations, kernel_shape, pads, the default value of // Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the // which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto // shape is unknown now, these attributes can be not generated auto
@ -686,7 +685,7 @@ private:
} }
void ImportGraph(const onnx::GraphProto &graph, void ImportGraph(const onnx::GraphProto &graph,
const std::string &name = "main") { const std::string &name = "main_graph") {
// create a function for the graph // create a function for the graph
// TODO: // TODO:
// * get name and type for the function. // * get name and type for the function.
@ -699,13 +698,18 @@ private:
} }
// TODO: import the initializer // TODO: import the initializer
auto func_type = builder_.getFunctionType(arg_types, {}); auto funcType = builder_.getFunctionType(arg_types, {});
auto main_func = auto mainFunc =
mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {}); mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
auto &entryBlock = *main_func.addEntryBlock(); auto entryPoint = mlir::ONNXEntryPointOp::create(
UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(),
/*numOutputs=*/graph.output().size());
auto &entryBlock = *mainFunc.addEntryBlock();
builder_.setInsertionPointToStart(&entryBlock); builder_.setInsertionPointToStart(&entryBlock);
module_.push_back(main_func);
module_.push_back(mainFunc);
module_.push_back(entryPoint);
for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) { for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) {
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it)); ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
@ -728,8 +732,8 @@ private:
builder_.create<mlir::ReturnOp>(UnknownLoc(), ret_vals); builder_.create<mlir::ReturnOp>(UnknownLoc(), ret_vals);
// Update main function signature to reflect types of newly imported // Update main function signature to reflect types of newly imported
// output tensors. // output tensors.
func_type = builder_.getFunctionType(arg_types, ret_types); funcType = builder_.getFunctionType(arg_types, ret_types);
main_func.setType(func_type); mainFunc.setType(funcType);
} }
}; // FrontendGenImpl class }; // FrontendGenImpl class
} // namespace } // namespace

View File

@ -21,7 +21,7 @@
namespace mlir { namespace mlir {
class MLIRContext; class MLIRContext;
class OwningModuleRef; class OwningModuleRef;
} // namespace mlir } // namespace mlir
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Import a model into one of ONNF's frontend models. // Import a model into one of ONNF's frontend models.
@ -41,7 +41,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
* @return MLIR::module generated for the ONNX model. * @return MLIR::module generated for the ONNX model.
*/ */
void ImportFrontendModelFile(std::string model_fname, void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext& context, mlir::OwningModuleRef& module); mlir::MLIRContext &context,
mlir::OwningModuleRef &module);
/*! /*!
* TODO: Import models into other extension dialects that cover the * TODO: Import models into other extension dialects that cover the

View File

@ -0,0 +1,5 @@
add_library(DLCAnalysis STATIC
extract_integer_set.cpp)
target_include_directories(DLCAnalysis PRIVATE ${DLC_SRC_ROOT})
target_include_directories(DLCAnalysis PRIVATE ${DLC_BIN_ROOT})

View File

@ -364,6 +364,14 @@ ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
return success(); return success();
} }
void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state,
SymbolRefAttr funcAttr, IntegerAttr numInputs,
IntegerAttr numOutputs) {
state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr);
state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs);
state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs);
}
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/compiler/krnl.cpp.inc" #include "src/compiler/krnl.cpp.inc"
} // namespace mlir } // namespace mlir

View File

@ -11,6 +11,7 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"

View File

@ -8,35 +8,27 @@
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def Krnl_Dialect : Dialect { def Krnl_Dialect : Dialect {
let name = "krnl"; let name = "krnl";
let cppNamespace = ""; let cppNamespace = "";
} }
// Require regions to have krnl.terminate terminator operation. // Require regions to have krnl.terminate terminator operation.
def ImplicitKrnlTerminator def ImplicitKrnlTerminator : SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
: SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> { def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
let summary = "define_loops operation"; let summary = "define_loops operation";
let description = [{ let description = [{
The "krnl.define_loops" operation is used to define input loops, The "krnl.define_loops" operation is used to define input loops,
those are the for loops appearing in the input program that we those are the for loops appearing in the input program that we
intend to optimize. intend to optimize.
}]; }];
let arguments = (ins); let arguments = (ins);
let results = (outs Variadic<AnyType>); let results = (outs Variadic<AnyType>);
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ OpBuilder<"Builder *builder, OperationState &result,"
let builders = [ "int64_t num_loops"> ];
OpBuilder<"Builder *builder, OperationState &result,"
"int64_t num_loops">
];
let printer = [{ return ::print(p, *this); }]; let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }]; let parser = [{ return ::parse$cppClass(parser, result); }];
@ -44,33 +36,27 @@ def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static StringRef getNumLoopsAttrName() { return "num_loops"; } static StringRef getNumLoopsAttrName() { return "num_loops"; }
// Helper function to extract the number of loops being defined. // Helper function to extract the number of loops being defined.
int64_t getNumLoops() { int64_t getNumLoops() {
auto num_loops = auto num_loops = getAttrOfType<IntegerAttr>(getNumLoopsAttrName())
getAttrOfType<IntegerAttr>( .getValue()
getNumLoopsAttrName()) .getSExtValue();
.getValue() return num_loops;
.getSExtValue(); }
return num_loops; }];
}
}];
} }
def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> { def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
let summary = "optimize_loops operation"; let summary = "optimize_loops operation";
let description = [{ let description = [{
The "krnl.optimize_loops" operation is essentially a cosmetic operation The "krnl.optimize_loops" operation is essentially a cosmetic operation
which exists to encapsulate a region where loops are being scheduled/optimized. which exists to encapsulate a region where loops are being scheduled /
optimized.
The optimized loops are returned at the end of the The optimized loops are returned at the end of the region associated with
region associated with the krnl.optimize_loops operation. the krnl.optimize_loops operation.
For example:
TBD once we have actual schedule intrinsics.
For example : TBD once we have actual schedule intrinsics.
}]; }];
let arguments = (ins Variadic<AnyType>); let arguments = (ins Variadic<AnyType>);
@ -79,10 +65,8 @@ def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
OpBuilder<"Builder *builder, OperationState &result, " "int timestamp_space_rank"> ];
"int timestamp_space_rank">
];
let printer = [{ return ::print(p, *this); }]; let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }]; let parser = [{ return ::parse$cppClass(parser, result); }];
@ -91,7 +75,6 @@ def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> { def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
let summary = "iterate operation"; let summary = "iterate operation";
let description = [{ let description = [{
The "krnl.iterate" operation is conceptually equivalent to a nested for loops. The "krnl.iterate" operation is conceptually equivalent to a nested for loops.
For instance, say we have the following two For instance, say we have the following two
@ -103,25 +86,20 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
Then, consider the following krnl.iterate operation: Then, consider the following krnl.iterate operation:
krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) { krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) {
// Some operations. // Some operations.
} }
It is equivalent to: It is equivalent to:
for (i0=0; i0<10; i0++) for (i0 = 0; i0 < 10; i0++)
for (i1=0; i1<10; i1++) for (i1 = 0; i1 < 10; i1++)
// Some operations. // Some operations.
}]; }];
let arguments = (ins Variadic<AnyType>); let arguments = (ins Variadic<AnyType>);
let regions = (region SizedRegion<1>:$bodyRegion); let regions = (region SizedRegion<1>:$bodyRegion);
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
let builders = [ "KrnlIterateOperandPack operandPack"> ];
OpBuilder<"Builder *builder, OperationState &result, "
"KrnlIterateOperandPack operandPack">
];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
// In krnl.iterate operation, operands are stored as such // In krnl.iterate operation, operands are stored as such
@ -134,20 +112,19 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
int64_t getNumOptimizedLoops() { int64_t getNumOptimizedLoops() {
auto num_optimized_loops = auto num_optimized_loops =
getAttrOfType<IntegerAttr>( getAttrOfType<IntegerAttr>(getNumOptimizedLoopsAttrName())
getNumOptimizedLoopsAttrName()) .getValue()
.getValue() .getSExtValue();
.getSExtValue();
return num_optimized_loops; return num_optimized_loops;
} }
// Get name of the attribute for storing bound represented using affine maps. // Get name of the attribute for storing bound represented using affine maps.
static StringRef getBoundsAttrName() { return "bounds"; } static StringRef getBoundsAttrName() { return "bounds"; }
}]; }];
let printer = [{ return ::print(p, *this); }]; let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }]; let parser = [{ return ::parse$cppClass(parser, result); }];
let verifier = [{ return ::verify(*this); }]; let verifier = [{ return ::verify(*this); }];
} }
def KrnlReturnLoopsOp : Op<Krnl_Dialect, "return_loops", [Terminator]> { def KrnlReturnLoopsOp : Op<Krnl_Dialect, "return_loops", [Terminator]> {
@ -182,11 +159,30 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
let verifier = ?; let verifier = ?;
} }
def KrnlEntryPointOp : Op<Krnl_Dialect, "entry_point"> {
let summary = "Indicate ONNX entry point";
let description = [{The "krnl.entry_point" function indicates the main entry
point of ONNX model.}];
let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
"SymbolRefAttr funcAttr, IntegerAttr numInputs, "
"IntegerAttr numOutputs"> ];
let extraClassDeclaration = [{
static StringRef getEntryPointFuncAttrName() { return "func"; }
static StringRef getNumInputsAttrName() { return "numInputs"; }
static StringRef getNumOutputsAttrName() { return "numOutputs"; }
}];
// No custom parsing/printing form.
let parser = ?;
let printer = ?;
}
def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> { def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
let summary = "Krnl memcpy operation"; let summary = "Krnl memcpy operation";
let description = [{ let description = [{
In the KRNL dialect the reshape op doesn't generate a new memory entry and In the KRNL dialect the reshape op
treats a reshape like a cast. doesn't generate a new memory entry and treats a reshape like a cast.
}]; }];
let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size); let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);

View File

@ -58,6 +58,27 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
include "dialect/onnx/onnxop.inc" include "dialect/onnx/onnxop.inc"
// Indicate entry point functions of ONNX graph.
def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
let summary = "Indicate ONNX entry point";
let description = [{
The "onnx.EntryPoint" function indicates the main entry point of ONNX model.
}];
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
FuncOp function, int numInputs, int numOutputs}]>];
let extraClassDeclaration = [{
static ONNXEntryPointOp create(Location location, FuncOp& func,
int numInputs, int numOutputs);
static StringRef getEntryPointFuncAttrName() { return "func"; }
static StringRef getNumInputsAttrName() { return "numInputs"; }
static StringRef getNumOutputsAttrName() { return "numOutputs"; }
}];
}
def ONNXFullGemmOp: ONNX_Op<"FullGemm", def ONNXFullGemmOp: ONNX_Op<"FullGemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation"; let summary = "ONNX general matrix multiply operation";

View File

@ -7,7 +7,6 @@
// This file defines ONNX operations in the MLIR operation set. // This file defines ONNX operations in the MLIR operation set.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
@ -38,6 +37,28 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
>(); >();
} }
void ONNXEntryPointOp::build(mlir::Builder *builder,
mlir::OperationState &state, mlir::FuncOp function,
int numInputs, int numOutputs) {
state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
builder->getSymbolRefAttr(function));
state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
builder->getI32IntegerAttr(numInputs));
state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
builder->getI32IntegerAttr(numOutputs));
}
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
mlir::FuncOp &func, int numInputs,
int numOutputs) {
mlir::OperationState state(location, "onnx.EntryPoint");
Builder builder(location->getContext());
mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
Operation *op = mlir::Operation::create(state);
auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
return onnxEntryOp;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ONNX Operations // ONNX Operations
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -10,6 +10,7 @@
#pragma once #pragma once
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"

View File

@ -8,7 +8,6 @@
// Krnl IR and standard operations. // Krnl IR and standard operations.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <map> #include <map>
#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/AffineOps/AffineOps.h"
@ -884,6 +883,27 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
} }
}; };
//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//
class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public:
using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
op,
op.getAttrOfType<SymbolRefAttr>(
ONNXEntryPointOp::getEntryPointFuncAttrName()),
op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
op.getAttrOfType<IntegerAttr>(
ONNXEntryPointOp::getNumOutputsAttrName()));
return matchSuccess();
}
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type. // Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -985,7 +1005,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering>(&getContext()); ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal` // conversion. The conversion will signal failure if any of our `illegal`

View File

@ -143,14 +143,16 @@ void KrnlToAffineLoweringPass::runOnFunction() {
// We expect IR to be free of Krnl Dialect Ops. // We expect IR to be free of Krnl Dialect Ops.
target.addIllegalDialect<KrnlOpsDialect>(); target.addIllegalDialect<KrnlOpsDialect>();
target.addLegalOp<KrnlMemcpyOp>(); target.addLegalOp<KrnlMemcpyOp>();
target.addLegalOp<KrnlEntryPointOp>();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering, patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>( KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
&getContext()); &getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns))) if (failed(applyPartialConversion(getFunction(), target, patterns))) {
signalPassFailure(); signalPassFailure();
}
} }
} // namespace } // namespace

View File

@ -4,7 +4,6 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "llvm/ADT/Sequence.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
@ -15,6 +14,7 @@
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "src/compiler/dialect/krnl/krnl_ops.hpp" #include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/pass/passes.hpp" #include "src/compiler/pass/passes.hpp"
@ -23,20 +23,39 @@ using namespace mlir;
namespace { namespace {
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
ModuleOp module,
mlir::LLVM::LLVMType funcType,
PatternRewriter &rewriter) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
auto symbolRef = SymbolRefAttr::get(funcName, context);
assert(symbolRef.getType() == funcType && "wrong symbol type");
return symbolRef;
}
// Insert the function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
return SymbolRefAttr::get(funcName, context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// KRNL to LLVM: patterns which need a direct lowering to LLVM. // KRNL to LLVM: KrnlMemcpyOpLowering
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class KrnlMemcpyOpLowering : public ConversionPattern { class KrnlMemcpyOpLowering : public ConversionPattern {
public: public:
explicit KrnlMemcpyOpLowering(MLIRContext* context) explicit KrnlMemcpyOpLowering(MLIRContext *context)
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands, PatternMatchResult
ConversionPatternRewriter& rewriter) const override { matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
auto* context = op->getContext(); ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
auto loc = op->getLoc(); auto loc = op->getLoc();
auto* llvmDialect = auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered"); assert(llvmDialect && "expected llvm dialect to be registered");
@ -47,39 +66,40 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
// First operand. // First operand.
Type dstType = Type dstType =
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1); operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value* alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>( Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
Value* alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>( Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
// Second operand. // Second operand.
Type srcType = Type srcType =
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1); operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value* alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>( Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
Value* alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>( Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
// Size. // Size.
Value* int64Size = rewriter.create<LLVM::SExtOp>( Value *int64Size = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
// Memcpy call // Memcpy call
rewriter.create<CallOp>(loc, memcpyRef, rewriter.create<CallOp>(
LLVM::LLVMType::getVoidTy(llvmDialect), loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value*>( ArrayRef<Value *>(
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size})); {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
rewriter.eraseOp(op); rewriter.eraseOp(op);
return matchSuccess(); return matchSuccess();
} }
private: private:
/// Return a symbol reference to the memcpy function, inserting it into the /// Return a symbol reference to the memcpy function, inserting it into the
/// module if necessary. /// module if necessary.
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter, static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
ModuleOp module, LLVM::LLVMDialect* llvmDialect) { ModuleOp module,
auto* context = module.getContext(); LLVM::LLVMDialect *llvmDialect) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64")) if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
// Create a function declaration for memcpy, the signature is: // Create a function declaration for memcpy, the signature is:
@ -87,19 +107,355 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect); auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy, auto llvmFnType = LLVM::LLVMType::getFunctionTy(
llvmVoidTy,
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}), ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
false); false);
// Insert the memcpy function into the body of the parent module. // Insert the memcpy function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter); PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody()); rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>( rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(),
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
} }
}; };
} // end namespace
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlEntryPointOp
//===----------------------------------------------------------------------===//
class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
public:
using OpRewritePattern<KrnlEntryPointOp>::OpRewritePattern;
enum class API {
CREATE_ORDERED_DYN_MEM_REF_DICT,
CREATE_DYN_MEM_REF,
GET_DYN_MEM_REF,
SET_DYN_MEM_REF,
GET_DATA,
SET_DATA,
GET_SIZES,
GET_STRIDES,
};
struct ApiSpec {
API id;
std::string name;
FlatSymbolRefAttr symbolRef;
LLVM::LLVMType outputTy;
SmallVector<LLVM::LLVMType, 4> inputTys;
ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy,
ArrayRef<LLVM::LLVMType> inputTys)
: id(id), name(name), outputTy(outputTy),
inputTys(inputTys.begin(), inputTys.end()) {}
LLVM::LLVMType funcTy() {
return LLVM::LLVMType::getFunctionTy(outputTy, inputTys,
/*isVarArg=*/false);
}
};
PatternMatchResult matchAndRewrite(KrnlEntryPointOp op,
PatternRewriter &rewriter) const override {
auto *llvmDialect =
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
auto module = op.getParentOfType<ModuleOp>();
auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect);
auto loc = op.getLoc();
auto numOutputs =
op.getAttrOfType<IntegerAttr>(KrnlEntryPointOp::getNumOutputsAttrName())
.getInt();
using LLVMType = LLVM::LLVMType;
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
// Rewrite the Krnl Entry Point Operation to an LLVM function with a dynamic
// signature. The signature is called dynamic because it stays the same no
// matter what the model input/output schema looks like; such a dynamic
// signature takes an opaque ptr as input, representing a ptr to a data
// structure containing a set of dynamic memrefs wrapped in a vector.
// Similarly, the output is also an opaque ptr to a data structure with the
// output memrefs wrapped within it.
auto staticEntryPointFuncName =
op.getAttrOfType<SymbolRefAttr>(
KrnlEntryPointOp::getEntryPointFuncAttrName())
.getLeafReference();
auto dynEntryPointName = "_dyn_entry_point_" + staticEntryPointFuncName;
assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr &&
"dynamic entry point name is not unique");
rewriter.eraseOp(op);
auto dynEntryPointFuncTy =
LLVMType::getFunctionTy(opaquePtrTy, {opaquePtrTy}, false);
auto dynamicEntryPointFunc = rewriter.create<LLVM::LLVMFuncOp>(
loc, dynEntryPointName.str(), dynEntryPointFuncTy);
auto &entryPointEntryBlock =
createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc);
rewriter.setInsertionPointToStart(&entryPointEntryBlock);
// Based on the static entry point type signature, unpack dynamic memory
// refs to corresponding static memory refs.
auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
assert(staticEntryPointFunc &&
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
"entry point func must exist and be an llvm func op");
auto staticEntryPointTy = dyn_cast<LLVM::LLVMFuncOp>(staticEntryPointFunc)
.getType()
.dyn_cast<LLVMType>();
// Retrieve dynamic mem refs from wrapped input, and convert every one of
// them to static mem refs.
SmallVector<Value *, 4> staticInputs;
auto wrappedInput = entryPointEntryBlock.getArgument(0);
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
// Call API function to retrieve the i-th dynamic memref.
auto idxVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(i));
auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
{wrappedInput, idxVal});
// Create a (static) memref type corresponding to the i-th memref input to
// the inference function on stack, and load it to memRef.
auto memRefPtrTy = staticEntryPointTy.getFunctionParamType(i);
auto memRefTy = memRefPtrTy.getPointerElementTy();
auto one = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(1));
Value *ptrToMemRef =
rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
/*alignment=*/0);
// Fill in the memref underlying ptrToMemRef with information extracted
// from dynMemRef.
fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc,
apiRegistry, llvmDialect);
// ptrToMemRef will be an input to main computation graph function.
staticInputs.emplace_back(ptrToMemRef);
}
// If more than one output exists, the struct becomes a nested struct and
// the unpacking logic gets more involved, so only a single output is
// supported for now.
assert(numOutputs == 1 && "only support 1 output tensor now.");
// Call static entry point with the memref ptrs created, and get output.
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
loc, staticEntryPointTy.getFunctionResultType(),
rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
// Create wrapped output.
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
// Get the first memref returned, convert to a dynamic memref and store
// it in the wrapped Output.
auto outMemRef = outputMemRefs.getResult(0);
auto outMemRefTy = outMemRef->getType().dyn_cast<LLVMType>();
auto outMemRefRank =
outMemRefTy.getStructElementType(3).getArrayNumElements();
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc,
apiRegistry, llvmDialect);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(0));
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
{wrappedOutput, zero, outDynMemRef});
// Return wrapped output.
rewriter.create<LLVM::ReturnOp>(loc,
SmallVector<Value *, 1>({wrappedOutput}));
return matchSuccess();
}
private:
using ApiRegistry = std::map<API, ApiSpec>;
ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter,
LLVM::LLVMDialect *llvmDialect) const {
using LLVMType = LLVM::LLVMType;
auto voidTy = LLVMType::getVoidTy(llvmDialect);
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
auto int64Ty = LLVMType::getInt64Ty(llvmDialect);
auto int64PtrTy = int64Ty.getPointerTo();
// Declare API type as an enum value, its string name and an LLVM Type
// specifying its signature.
// clang-format off
std::vector<ApiSpec> apiSpecs = {
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}),
ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}),
ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy})
};
// clang-format on
// Declare APIs in the current module and build an API registry mapping api
// identities to a symbol reference to the API function.
ApiRegistry registry;
for (auto &apiSpec : apiSpecs) {
apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module,
apiSpec.funcTy(), rewriter);
registry.emplace(apiSpec.id, apiSpec);
}
return registry;
}
// Call a registered API; return its single SSA result if exactly one result
// is returned, otherwise return nullptr.
Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value *> params) const {
auto returnVals = rewriter.create<LLVM::CallOp>(
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
ArrayRef<Value *>(params));
if (returnVals.getNumResults() == 1)
return returnVals.getResult(0);
return nullptr;
}
// Helper function to insert an entry block into an LLVM function.
// (TODO): upstream this to MLIR.
Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType,
LLVM::LLVMFuncOp &dynamicEntryPointFunc) const {
// Add entry block:
auto *entryPointEntryBlock = new Block();
dynamicEntryPointFunc.push_back(entryPointEntryBlock);
llvm::SmallVector<Type, 4> argTypes;
for (size_t i = 0; i < dynEntryPointFuncType.getFunctionNumParams(); i++)
argTypes.emplace_back(dynEntryPointFuncType.getFunctionParamType(i));
entryPointEntryBlock->addArguments(argTypes);
return *entryPointEntryBlock;
}
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
PatternRewriter &rewriter,
const Location &loc,
const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const {
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto memRefTy = memRefPtrTy.getPointerElementTy();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
Value *memRef =
rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef);
// Set dataPtr and alignedDataPtr;
auto dataPtr =
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef});
dataPtr = rewriter.create<LLVM::BitcastOp>(
loc, memRefTy.getStructElementType(0), dataPtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dataPtr,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dataPtr,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
// Use zero offset now.
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int64Ty, rewriter.getI64IntegerAttr(0));
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, zero,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
// Get rank, sizes array ptr and strides array ptr.
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef});
auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef});
for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
loc, int64Ty, rewriter.getI64IntegerAttr(i));
// Insert size of the dimension.
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx}));
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
dimSizePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dimSize,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
// Insert stride of the dimension.
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), stridesArrayPtr,
ArrayRef<Value *>({dimIdx}));
auto dimStride = rewriter.create<LLVM::LoadOp>(
loc, int64Ty.getPointerTo(), dimStridePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dimStride,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
}
rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef);
}
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
PatternRewriter &rewriter, const Location &loc,
const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const {
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
// Extract the data pointer, and record it in dynamic mem ref created.
Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, outMemRefTy.getStructElementType(0), &outMemRef,
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
{&outDynMemRef, outMemRefDataPtr});
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef});
auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef});
for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
loc, int64Ty, rewriter.getI64IntegerAttr(i));
// Transfer size of dimension from memref to dynamic memref.
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, &outMemRef,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
// Transfer stride of dimension from memref to dynamic memref.
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, &outMemRef,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), stridesArrayPtr,
ArrayRef<Value *>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
}
}
};
} // end namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// KRNL + Stadard + Affine dialects lowering to LLVM. // KRNL + Stadard + Affine dialects lowering to LLVM.
@ -109,7 +465,7 @@ namespace {
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> { struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
void runOnModule() final; void runOnModule() final;
}; };
} // end anonymous namespace } // end anonymous namespace
void KrnlToLLVMLoweringPass::runOnModule() { void KrnlToLLVMLoweringPass::runOnModule() {
// Define the target for this lowering i.e. the LLVM dialect. // Define the target for this lowering i.e. the LLVM dialect.
@ -128,12 +484,13 @@ void KrnlToLLVMLoweringPass::runOnModule() {
populateStdToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns);
// Lower from the `krnl` dialect i.e. the Reshape operation. // Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering>(&getContext()); patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This // We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion. // ensures that only legal operations will remain after the conversion.
auto module = getModule(); if (failed(
if (failed(applyFullConversion(module, target, patterns, &typeConverter))) applyFullConversion(getModule(), target, patterns, &typeConverter)))
signalPassFailure(); signalPassFailure();
} }
@ -142,5 +499,5 @@ std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
return std::make_unique<KrnlToLLVMLoweringPass>(); return std::make_unique<KrnlToLLVMLoweringPass>();
} }
static PassRegistration<KrnlToLLVMLoweringPass> pass( static PassRegistration<KrnlToLLVMLoweringPass>
"lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM."); pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");

View File

@ -7,10 +7,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <cmath> #include <cmath>
#include <cstdlib>
#include <iostream> #include <iostream>
#include <random>
#include <tuple>
#include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Support/CommandLine.h" #include "llvm/Support/CommandLine.h"
@ -24,9 +21,7 @@
#include "src/compiler/dialect/onnx/onnx_ops.hpp" #include "src/compiler/dialect/onnx/onnx_ops.hpp"
#include "src/compiler/pass/passes.hpp" #include "src/compiler/pass/passes.hpp"
#include "mlir/Analysis/Verifier.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"

View File

@ -0,0 +1,17 @@
add_library(cruntime
dyn_memref.cpp
dyn_memref.h
data_type.h)
target_include_directories(cruntime
PRIVATE ${DLC_SRC_ROOT} ${DLC_BIN_ROOT}
${DLC_SRC_ROOT})
pybind11_add_module(pyruntime
dyn_memref.cpp
dyn_memref.h
runtime.cpp
runtime.hpp)
target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
target_include_directories(pyruntime
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})

30  src/runtime/data_type.h  Normal file
View File

@ -0,0 +1,30 @@
enum DYN_MEMREF_DATA_TYPE {
  UNDEFINED = 0,
  // Basic types.
  FLOAT = 1,   // float
  UINT8 = 2,   // uint8_t
  INT8 = 3,    // int8_t
  UINT16 = 4,  // uint16_t
  INT16 = 5,   // int16_t
  INT32 = 6,   // int32_t
  INT64 = 7,   // int64_t
  STRING = 8,  // string
  BOOL = 9,    // bool

  // IEEE754 half-precision floating-point format (16 bits wide).
  // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
  FLOAT16 = 10,

  DOUBLE = 11,
  UINT32 = 12,
  UINT64 = 13,
  COMPLEX64 = 14,  // complex with float32 real and imaginary components
  COMPLEX128 = 15, // complex with float64 real and imaginary components

  // Non-IEEE floating-point format based on IEEE754 single-precision
  // floating-point number truncated to 16 bits.
  // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
  BFLOAT16 = 16
  // Future extensions go here.
};
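
The commit message mentions adding support for data types; as one illustration of how these tags could be consumed, a hypothetical helper (not part of this commit) might map each tag to its element width in bytes:

```cpp
// Hypothetical helper (not in this commit): byte width of each element type.
// Returns 0 for types without a fixed size (UNDEFINED, STRING).
inline int getDataTypeSize(DYN_MEMREF_DATA_TYPE type) {
  switch (type) {
  case UINT8: case INT8: case BOOL:
    return 1;
  case UINT16: case INT16: case FLOAT16: case BFLOAT16:
    return 2;
  case FLOAT: case INT32: case UINT32:
    return 4;
  case DOUBLE: case INT64: case UINT64: case COMPLEX64:
    return 8;
  case COMPLEX128:
    return 16;
  default:
    return 0;
  }
}
```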

View File

@ -0,0 +1,74 @@
#include <cassert>
#include <map>
#include <string>
#include <vector>
#include "dyn_memref.h"
DynMemRef::DynMemRef(int _rank) {
rank = _rank;
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
strides = (int64_t *)malloc(rank * sizeof(int64_t));
}
// An ordered dynamic MemRef dictionary.
// The goal is to support accessing dynamic memory ref by name and by index.
// Currently, only accessing by index is supported.
struct OrderedDynMemRefDict {
std::map<std::string, DynMemRef *> tensorDict;
std::vector<std::string> orderedNames;
};
int numDynMemRefs(OrderedDynMemRefDict *dict) {
return dict->orderedNames.size();
}
OrderedDynMemRefDict *createOrderedDynMemRefDict() {
return new OrderedDynMemRefDict();
}
DynMemRef *createDynMemRef(int rank) { return new DynMemRef(rank); }
DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
}
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
DynMemRef *tensor) {
if ((int)tensorDict->orderedNames.size() <= idx)
tensorDict->orderedNames.resize(idx + 1);
// The dynamic memref is essentially anonymous, since we are storing it by
// indexed position.
// TODO: can use random string as names to reduce chance of collision.
auto unique_name = std::to_string(idx);
assert(tensorDict->tensorDict.count(unique_name) == 0 &&
"duplicate dynamic mem ref name");
tensorDict->orderedNames[idx] = unique_name;
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
}
void *getData(DynMemRef *dynMemRef) { return dynMemRef->data; }
void setData(DynMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
void *getAlignedData(DynMemRef *dynMemRef) { return dynMemRef->alignedData; }
void setAlignedData(DynMemRef *dynMemRef, void *dataPtr) {
dynMemRef->alignedData = dataPtr;
}
INDEX_TYPE *getSizes(DynMemRef *dynMemRef) { return dynMemRef->sizes; }
void setSizes(DynMemRef *dynMemRef, INDEX_TYPE *sizes) {
for (int i = 0; i < dynMemRef->rank; i++)
dynMemRef->sizes[i] = sizes[i];
}
int64_t *getStrides(DynMemRef *dynMemRef) { return dynMemRef->strides; }
void setStrides(DynMemRef *dynMemRef, int64_t *strides) {
for (int i = 0; i < dynMemRef->rank; i++)
dynMemRef->strides[i] = strides[i];
}

61  src/runtime/dyn_memref.h  Normal file
View File

@ -0,0 +1,61 @@
#pragma once
#include <cstdint>
typedef int64_t INDEX_TYPE;
// This is a dynamic version of memref.
// The same struct can be used to represent memrefs of
// all ranks and type combinations.
struct DynMemRef {
void *data;
void *alignedData;
INDEX_TYPE offset;
unsigned int rank;
INDEX_TYPE *sizes;
int64_t *strides;
DynMemRef(int _rank);
};
extern "C" {
// Ordered DynMemRef Dictionary is a data structure for wrapping the input
// dynmemrefs so that they can be addressed both by index and by name.
struct OrderedDynMemRefDict;
// Get number of dynamic memrefs in OrderedDynMemRefDict dict.
int numDynMemRefs(OrderedDynMemRefDict *dict);
// Create an ordered dynmemref dictionary.
OrderedDynMemRefDict *createOrderedDynMemRefDict();
// Create a dynmemref with a certain rank.
DynMemRef *createDynMemRef(int rank);
// Get the i-th dynmemref from orderedDict.
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
// Set the i-th dynmemref in orderedDict to be dynMemRef.
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
DynMemRef *dynMemRef);
// Get data pointer from dynMemRef.
void *getData(DynMemRef *dynMemRef);
// Set data pointer for dynMemRef.
void setData(DynMemRef *dynMemRef, void *data);
// Get aligned data pointer from dynMemRef.
void *getAlignedData(DynMemRef *);
// Set aligned data pointer for dynMemRef.
void setAlignedData(DynMemRef *, void *);
// Get ptr to sizes array.
INDEX_TYPE *getSizes(DynMemRef *);
// Get ptr to strides array.
int64_t *getStrides(DynMemRef *);
}
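
Since this header exposes a plain C interface, a model compiled into model.so (as done by test/onnx_backend_test.py below) can also be driven from a small C++ program without the pybind11 wrapper. A minimal sketch, assuming a single 2x3 float input and using the same `_dyn_entry_point_main_graph` symbol that the backend test loads; error handling is omitted for brevity:

```cpp
#include <cstdio>
#include <dlfcn.h>

#include "src/runtime/dyn_memref.h"

// Signature of the generated dynamic entry point (see src/runtime/runtime.hpp).
typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);

int main() {
  // Load the compiled model and resolve its dynamic entry point.
  void *handle = dlopen("./model.so", RTLD_LAZY);
  auto entryPoint =
      (entryPointFuncType)dlsym(handle, "_dyn_entry_point_main_graph");

  // Wrap a single 2x3 float input as a DynMemRef.
  float inputData[6] = {0, 1, 2, 3, 4, 5};
  DynMemRef *input = createDynMemRef(/*rank=*/2);
  setData(input, inputData);
  setAlignedData(input, inputData);
  getSizes(input)[0] = 2;   getSizes(input)[1] = 3;
  getStrides(input)[0] = 3; getStrides(input)[1] = 1;

  OrderedDynMemRefDict *inputs = createOrderedDynMemRefDict();
  setDynMemRef(inputs, 0, input);

  // Run the model and read back the first output tensor.
  OrderedDynMemRefDict *outputs = entryPoint(inputs);
  DynMemRef *output = getDynMemRef(outputs, 0);
  printf("first output element: %f\n", ((float *)getData(output))[0]);

  dlclose(handle);
  return 0;
}
```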

52  src/runtime/runtime.cpp  Normal file
View File

@ -0,0 +1,52 @@
#include "runtime.hpp"
ExecutionSession::ExecutionSession(std::string sharedLibPath,
std::string entryPointName) {
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
_entryPointFunc =
(entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
}
std::vector<py::array>
ExecutionSession::run(std::vector<py::array> inputsPyArray) {
assert(_entryPointFunc && "entry point not loaded");
auto *wrappedInput = createOrderedDynMemRefDict();
int inputIdx = 0;
for (auto inputPyArray : inputsPyArray) {
auto *inputDynMemRef = createDynMemRef(inputPyArray.ndim());
assert((inputPyArray.flags() & py::array::c_style) &&
"expect contiguous python array");
if (inputPyArray.writeable()) {
inputDynMemRef->data = inputPyArray.mutable_data();
inputDynMemRef->alignedData = inputPyArray.mutable_data();
} else {
// If the data is not writable, copy it to a writable buffer.
auto *copiedData = (float *)malloc(inputPyArray.nbytes());
memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
inputDynMemRef->data = copiedData;
inputDynMemRef->alignedData = copiedData;
}
for (int i = 0; i < inputPyArray.ndim(); i++) {
inputDynMemRef->sizes[i] = inputPyArray.shape(i);
inputDynMemRef->strides[i] = inputPyArray.strides(i);
}
setDynMemRef(wrappedInput, inputIdx++, inputDynMemRef);
}
std::vector<py::array> outputPyArrays;
auto *wrappedOutput = _entryPointFunc(wrappedInput);
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
auto *dynMemRef = getDynMemRef(wrappedOutput, i);
auto shape = std::vector<int64_t>(dynMemRef->sizes,
dynMemRef->sizes + dynMemRef->rank);
outputPyArrays.emplace_back(
py::array(py::dtype("float32"), shape, dynMemRef->data));
}
return outputPyArrays;
}
ExecutionSession::~ExecutionSession() { dlclose(_sharedLibraryHandle); }

37  src/runtime/runtime.hpp  Normal file
View File

@ -0,0 +1,37 @@
#pragma once
#include <cassert>
#include <string>
#include <dlfcn.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "src/runtime/dyn_memref.h"
namespace py = pybind11;
typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);
class ExecutionSession {
public:
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
std::vector<py::array> run(std::vector<py::array> inputsPyArray);
~ExecutionSession();
private:
// Handler to the shared library file being loaded.
void *_sharedLibraryHandle = nullptr;
// Entry point function.
entryPointFuncType _entryPointFunc = nullptr;
};
PYBIND11_MODULE(pyruntime, m) {
py::class_<ExecutionSession>(m, "ExecutionSession")
.def(py::init<const std::string &, const std::string &>())
.def("run", &ExecutionSession::run);
}

162  test/onnx_backend_test.py  Normal file
View File

@ -0,0 +1,162 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import itertools
import os
import unittest
import onnx.backend.base
import onnx.backend.test
from onnx.backend.base import Device, DeviceType
import onnx.shape_inference
import onnx.version_converter
from typing import Optional, Text, Any, Tuple, Sequence
from onnx import NodeProto, ModelProto, TensorProto
import numpy # type: ignore
import subprocess
from pyruntime import ExecutionSession
CXX = os.getenv('CXX')
FE = os.getenv('FE')
LLC = os.getenv('LLC')
RT_DIR = os.getenv('RT_DIR')
assert CXX and FE and LLC and RT_DIR, "tools path not set"
class DummyBackend(onnx.backend.base.Backend):
@classmethod
def prepare(
cls,
model,
device='CPU',
**kwargs
):
super(DummyBackend, cls).prepare(model, device, **kwargs)
# Save model to disk as temp_model.onnx.
onnx.save(model, "temp_model.onnx")
# Call the frontend to process temp_model.onnx; LLVM bitcode (model.bc) will be generated.
subprocess.run([FE, "temp_model.onnx"], stdout=subprocess.PIPE)
# Call llc to generate object file from bitcode.
subprocess.run([LLC, "-filetype=obj", "model.bc"],
stdout=subprocess.PIPE)
# Generate a shared library from the object file, linking with the C runtime.
subprocess.run([
CXX, "-shared", "model.o", "-o", "model.so", "-L" + RT_DIR,
"-lcruntime"
],
stdout=subprocess.PIPE)
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
@classmethod
def supports_device(cls, device):
d = Device(device)
if d.type == DeviceType.CPU:
return True
return False
backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__)
# Test directories:
# https://github.com/onnx/onnx/tree/master/onnx/backend/test/data/node
test_to_enable = [
# Add Op:
"test_add_cpu",
"test_add_bcast_cpu",
# And Op:
# Sub Op:
"test_sub_cpu",
"test_sub_bcast_cpu",
"test_sub_example_cpu",
# Cosh Op:
"test_cosh_cpu",
"test_cosh_example_cpu",
# Div Op:
"test_div_cpu",
"test_div_bcast_cpu",
"test_div_example_cpu",
# Elu Op:
"test_elu_cpu",
"test_elu_default_cpu",
"test_elu_example_cpu",
# Exp Op:
"test_exp_cpu",
"test_exp_example_cpu",
# Hard Sigmoid Op:
"test_hardsigmoid_cpu",
"test_hardsigmoid_default_cpu",
"test_hardsigmoid_example_cpu",
# Leaky Relu Op:
"test_leakyrelu_cpu",
"test_leakyrelu_default_cpu",
"test_leakyrelu_example_cpu",
# Max Op:
# "test_max_example_cpu", <- error
"test_max_one_input_cpu",
# "test_max_two_inputs_cpu", <- error
# Min Op:
# "test_min_example_cpu", <- error
"test_min_one_input_cpu",
# "test_min_two_inputs_cpu", <- error
# Mul Op:
"test_mul_cpu",
"test_mul_bcast_cpu",
"test_mul_example_cpu",
# Relu Op:
"test_relu_cpu",
# Selu Op:
"test_selu_cpu",
"test_selu_default_cpu",
"test_selu_example_cpu",
# Sigmoid Op:
"test_sigmoid_cpu",
"test_sigmoid_example_cpu",
# Sum Op:
#"test_sum_example_cpu", <- error
"test_sum_one_input_cpu",
#"test_sum_two_inputs_cpu", <- error
# Reciprocal Op:
#"test_reciprocal_cpu", <- error on shape inference.
#"test_reciprocal_example_cpu", <- error on shape inference.
]
# Extract name of all test cases.
import inspect
all_tests = inspect.getmembers(
backend_test.test_cases["OnnxBackendNodeModelTest"])
all_test_names = list(map(lambda x: x[0], all_tests))
for test_name in test_to_enable:
assert test_name in all_test_names, "test name {} not found".format(test_name)
backend_test.include(r"^{}$".format(test_name))
def tearDownModule():
print()
print("*" * 40)
print("A total of {} tests should have run".format(len(test_to_enable)))
print("*" * 40)
# import all test cases at global scope to make them visible to python.unittest
globals().update(backend_test.test_cases)
if __name__ == '__main__':
unittest.main()