Enable ONNX Backend Test (#1)
* wip, commit before merging with upstream * organize API, return wrapped output * enable onnx backend test * undo unintentional commit * fix krnl ops tablegen * format krnl ops * reorder fillDynMemRefWithMemRef to be after fillPtrToMemRefWithDynMemRef, better comments * more onnx backend tests * ensure that test names refer to existing tests * improve code readability by shortening type names * nit * restore unintentional changes * more nits * fix ; -> : * split runtime implementation into header and body file, add support for data types * comment on the onnx backend test * make the comments read better * do not dump when lowering
This commit is contained in:
parent
5573cb39fe
commit
685bf23b40
|
@ -26,6 +26,7 @@ add_subdirectory(third_party/pybind11)
|
|||
set(CMAKE_CXX_STANDARD 14)
|
||||
add_subdirectory(src/builder)
|
||||
add_subdirectory(src/compiler)
|
||||
add_subdirectory(src/runtime)
|
||||
add_subdirectory(src)
|
||||
|
||||
add_subdirectory(test)
|
||||
|
|
|
@ -34,7 +34,7 @@ set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include)
|
|||
set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include)
|
||||
set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin)
|
||||
|
||||
set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/src/compiler/tool/onnf_opt)
|
||||
set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin)
|
||||
set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
|
||||
set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)
|
||||
|
||||
|
|
|
@ -614,7 +614,6 @@ private:
|
|||
onnx::NodeProto node, int nIn, int nOut,
|
||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||
attrs) {
|
||||
|
||||
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
||||
// which is determined by the shape of first argument. However, since the
|
||||
// shape is unknown now, these attributes can be not generated auto
|
||||
|
@ -686,7 +685,7 @@ private:
|
|||
}
|
||||
|
||||
void ImportGraph(const onnx::GraphProto &graph,
|
||||
const std::string &name = "main") {
|
||||
const std::string &name = "main_graph") {
|
||||
// create a function for the graph
|
||||
// TODO:
|
||||
// * get name and type for the function.
|
||||
|
@ -699,13 +698,18 @@ private:
|
|||
}
|
||||
|
||||
// TODO: import the initializer
|
||||
auto func_type = builder_.getFunctionType(arg_types, {});
|
||||
auto main_func =
|
||||
mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {});
|
||||
auto &entryBlock = *main_func.addEntryBlock();
|
||||
auto funcType = builder_.getFunctionType(arg_types, {});
|
||||
auto mainFunc =
|
||||
mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
|
||||
auto entryPoint = mlir::ONNXEntryPointOp::create(
|
||||
UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(),
|
||||
/*numOutputs=*/graph.output().size());
|
||||
|
||||
auto &entryBlock = *mainFunc.addEntryBlock();
|
||||
builder_.setInsertionPointToStart(&entryBlock);
|
||||
module_.push_back(main_func);
|
||||
|
||||
module_.push_back(mainFunc);
|
||||
module_.push_back(entryPoint);
|
||||
|
||||
for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) {
|
||||
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
|
||||
|
@ -728,8 +732,8 @@ private:
|
|||
builder_.create<mlir::ReturnOp>(UnknownLoc(), ret_vals);
|
||||
// Update main function signature to reflect types of newly imported
|
||||
// output tensors.
|
||||
func_type = builder_.getFunctionType(arg_types, ret_types);
|
||||
main_func.setType(func_type);
|
||||
funcType = builder_.getFunctionType(arg_types, ret_types);
|
||||
mainFunc.setType(funcType);
|
||||
}
|
||||
}; // FrontendGenImpl class
|
||||
} // namespace
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class OwningModuleRef;
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Import a model into one of ONNF's frontend models.
|
||||
|
@ -41,7 +41,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
|
|||
* @return MLIR::module generated for the ONNX model.
|
||||
*/
|
||||
void ImportFrontendModelFile(std::string model_fname,
|
||||
mlir::MLIRContext& context, mlir::OwningModuleRef& module);
|
||||
mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module);
|
||||
|
||||
/*!
|
||||
* TODO: Import models into other extension dialects that cover the
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
add_library(DLCAnalysis STATIC
|
||||
extract_integer_set.cpp)
|
||||
|
||||
target_include_directories(DLCAnalysis PRIVATE ${DLC_SRC_ROOT})
|
||||
target_include_directories(DLCAnalysis PRIVATE ${DLC_BIN_ROOT})
|
|
@ -364,6 +364,14 @@ ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state,
|
||||
SymbolRefAttr funcAttr, IntegerAttr numInputs,
|
||||
IntegerAttr numOutputs) {
|
||||
state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr);
|
||||
state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs);
|
||||
state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs);
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/compiler/krnl.cpp.inc"
|
||||
} // namespace mlir
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
|
|
|
@ -8,35 +8,27 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
|
||||
def Krnl_Dialect : Dialect {
|
||||
let name = "krnl";
|
||||
let cppNamespace = "";
|
||||
}
|
||||
|
||||
// Require regions to have krnl.terminate terminator operation.
|
||||
def ImplicitKrnlTerminator
|
||||
: SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
|
||||
def ImplicitKrnlTerminator : SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
|
||||
|
||||
def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
|
||||
let summary = "define_loops operation";
|
||||
let description = [{
|
||||
|
||||
The "krnl.define_loops" operation is used to define input loops,
|
||||
those are the for loops appearing in the input program that we
|
||||
intend to optimize.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result,"
|
||||
"int64_t num_loops">
|
||||
];
|
||||
let builders = [ OpBuilder<"Builder *builder, OperationState &result,"
|
||||
"int64_t num_loops"> ];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
|
@ -44,33 +36,27 @@ def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
|
|||
let extraClassDeclaration = [{
|
||||
static StringRef getNumLoopsAttrName() { return "num_loops"; }
|
||||
|
||||
// Helper function to extract the number of loops being defined.
|
||||
int64_t getNumLoops() {
|
||||
auto num_loops =
|
||||
getAttrOfType<IntegerAttr>(
|
||||
getNumLoopsAttrName())
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
return num_loops;
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
// Helper function to extract the number of loops being defined.
|
||||
int64_t getNumLoops() {
|
||||
auto num_loops = getAttrOfType<IntegerAttr>(getNumLoopsAttrName())
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
return num_loops;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
|
||||
let summary = "optimize_loops operation";
|
||||
let description = [{
|
||||
|
||||
The "krnl.optimize_loops" operation is essentially a cosmetic operation
|
||||
which exists to encapsulate a region where loops are being scheduled/optimized.
|
||||
which exists to encapsulate a region where loops are being scheduled /
|
||||
optimized.
|
||||
|
||||
The optimized loops are returned at the end of the
|
||||
region associated with the krnl.optimize_loops operation.
|
||||
|
||||
For example:
|
||||
TBD once we have actual schedule intrinsics.
|
||||
The optimized loops are returned at the end of the region associated with
|
||||
the krnl.optimize_loops operation.
|
||||
|
||||
For example : TBD once we have actual schedule intrinsics.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
|
@ -79,10 +65,8 @@ def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
|
|||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"int timestamp_space_rank">
|
||||
];
|
||||
let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"int timestamp_space_rank"> ];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
|
@ -91,7 +75,6 @@ def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
|
|||
def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
||||
let summary = "iterate operation";
|
||||
let description = [{
|
||||
|
||||
The "krnl.iterate" operation is conceptually equivalent to a nested for loops.
|
||||
|
||||
For instance, say we have the following two
|
||||
|
@ -103,25 +86,20 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
|||
|
||||
Then, consider the following krnl.iterate operation:
|
||||
krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) {
|
||||
// Some operations.
|
||||
// Some operations.
|
||||
}
|
||||
|
||||
It is equivalent to:
|
||||
for (i0=0; i0<10; i0++)
|
||||
for (i1=0; i1<10; i1++)
|
||||
for (i0 = 0; i0 < 10; i0++)
|
||||
for (i1 = 0; i1 < 10; i1++)
|
||||
// Some operations.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
|
||||
let regions = (region SizedRegion<1>:$bodyRegion);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"KrnlIterateOperandPack operandPack">
|
||||
];
|
||||
let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"KrnlIterateOperandPack operandPack"> ];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// In krnl.iterate operation, operands are stored as such
|
||||
|
@ -134,20 +112,19 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
|||
|
||||
int64_t getNumOptimizedLoops() {
|
||||
auto num_optimized_loops =
|
||||
getAttrOfType<IntegerAttr>(
|
||||
getNumOptimizedLoopsAttrName())
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
getAttrOfType<IntegerAttr>(getNumOptimizedLoopsAttrName())
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
return num_optimized_loops;
|
||||
}
|
||||
|
||||
// Get name of the attribute for storing bound represented using affine maps.
|
||||
static StringRef getBoundsAttrName() { return "bounds"; }
|
||||
static StringRef getBoundsAttrName() { return "bounds"; }
|
||||
}];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def KrnlReturnLoopsOp : Op<Krnl_Dialect, "return_loops", [Terminator]> {
|
||||
|
@ -182,11 +159,30 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
|
|||
let verifier = ?;
|
||||
}
|
||||
|
||||
def KrnlEntryPointOp : Op<Krnl_Dialect, "entry_point"> {
|
||||
let summary = "Indicate ONNX entry point";
|
||||
let description = [{The "krnl.entry_point" function indicates the main entry
|
||||
point of ONNX model.}];
|
||||
let builders = [ OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"SymbolRefAttr funcAttr, IntegerAttr numInputs, "
|
||||
"IntegerAttr numOutputs"> ];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getEntryPointFuncAttrName() { return "func"; }
|
||||
static StringRef getNumInputsAttrName() { return "numInputs"; }
|
||||
static StringRef getNumOutputsAttrName() { return "numOutputs"; }
|
||||
}];
|
||||
|
||||
// No custom parsing/printing form.
|
||||
let parser = ?;
|
||||
let printer = ?;
|
||||
}
|
||||
|
||||
def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
|
||||
let summary = "Krnl memcpy operation";
|
||||
let description = [{
|
||||
In the KRNL dialect the reshape op doesn't generate a new memory entry and
|
||||
treats a reshape like a cast.
|
||||
In the KRNL dialect the reshape op
|
||||
doesn't generate a new memory entry and treats a reshape like a cast.
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);
|
||||
|
|
|
@ -58,6 +58,27 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
|
||||
include "dialect/onnx/onnxop.inc"
|
||||
|
||||
// Indicate entry point functions of ONNX graph.
|
||||
def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
||||
let summary = "Indicate ONNX entry point";
|
||||
let description = [{
|
||||
The "onnx.EntryPoint" function indicates the main entry point of ONNX model.
|
||||
}];
|
||||
|
||||
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
|
||||
FuncOp function, int numInputs, int numOutputs}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static ONNXEntryPointOp create(Location location, FuncOp& func,
|
||||
int numInputs, int numOutputs);
|
||||
|
||||
static StringRef getEntryPointFuncAttrName() { return "func"; }
|
||||
static StringRef getNumInputsAttrName() { return "numInputs"; }
|
||||
static StringRef getNumOutputsAttrName() { return "numOutputs"; }
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX general matrix multiply operation";
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
// This file defines ONNX operations in the MLIR operation set.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -38,6 +37,28 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
|
|||
>();
|
||||
}
|
||||
|
||||
void ONNXEntryPointOp::build(mlir::Builder *builder,
|
||||
mlir::OperationState &state, mlir::FuncOp function,
|
||||
int numInputs, int numOutputs) {
|
||||
state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
|
||||
builder->getSymbolRefAttr(function));
|
||||
state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
|
||||
builder->getI32IntegerAttr(numInputs));
|
||||
state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
|
||||
builder->getI32IntegerAttr(numOutputs));
|
||||
}
|
||||
|
||||
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
|
||||
mlir::FuncOp &func, int numInputs,
|
||||
int numOutputs) {
|
||||
mlir::OperationState state(location, "onnx.EntryPoint");
|
||||
Builder builder(location->getContext());
|
||||
mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
|
||||
Operation *op = mlir::Operation::create(state);
|
||||
auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
|
||||
return onnxEntryOp;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNX Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
// Krnl IR and standard operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||
|
@ -884,6 +883,27 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EntryPoint Op lowering to Krnl Entry Point.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
|
||||
public:
|
||||
using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
|
||||
op,
|
||||
op.getAttrOfType<SymbolRefAttr>(
|
||||
ONNXEntryPointOp::getEntryPointFuncAttrName()),
|
||||
op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
|
||||
op.getAttrOfType<IntegerAttr>(
|
||||
ONNXEntryPointOp::getNumOutputsAttrName()));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion from Tensor type to the Standard dialect MemRef type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -985,7 +1005,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
|||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
||||
ONNXReshapeOpLowering>(&getContext());
|
||||
ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext());
|
||||
|
||||
// With the target and rewrite patterns defined, we can now attempt the
|
||||
// conversion. The conversion will signal failure if any of our `illegal`
|
||||
|
|
|
@ -143,14 +143,16 @@ void KrnlToAffineLoweringPass::runOnFunction() {
|
|||
// We expect IR to be free of Krnl Dialect Ops.
|
||||
target.addIllegalDialect<KrnlOpsDialect>();
|
||||
target.addLegalOp<KrnlMemcpyOp>();
|
||||
target.addLegalOp<KrnlEntryPointOp>();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
|
||||
&getContext());
|
||||
|
||||
if (failed(applyPartialConversion(getFunction(), target, patterns)))
|
||||
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
|
@ -15,6 +14,7 @@
|
|||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
#include "src/compiler/pass/passes.hpp"
|
||||
|
@ -23,20 +23,39 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
|
||||
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
||||
ModuleOp module,
|
||||
mlir::LLVM::LLVMType funcType,
|
||||
PatternRewriter &rewriter) {
|
||||
auto *context = module.getContext();
|
||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
|
||||
auto symbolRef = SymbolRefAttr::get(funcName, context);
|
||||
assert(symbolRef.getType() == funcType && "wrong symbol type");
|
||||
return symbolRef;
|
||||
}
|
||||
|
||||
// Insert the function into the body of the parent module.
|
||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
|
||||
return SymbolRefAttr::get(funcName, context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KRNL to LLVM: patterns which need a direct lowering to LLVM.
|
||||
// KRNL to LLVM: KrnlMemcpyOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class KrnlMemcpyOpLowering : public ConversionPattern {
|
||||
public:
|
||||
explicit KrnlMemcpyOpLowering(MLIRContext* context)
|
||||
public:
|
||||
explicit KrnlMemcpyOpLowering(MLIRContext *context)
|
||||
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
auto* context = op->getContext();
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto *context = op->getContext();
|
||||
auto loc = op->getLoc();
|
||||
auto* llvmDialect =
|
||||
auto *llvmDialect =
|
||||
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||
|
||||
|
@ -47,39 +66,40 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
|
|||
// First operand.
|
||||
Type dstType =
|
||||
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||
Value* alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
|
||||
Value* alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
||||
|
||||
// Second operand.
|
||||
Type srcType =
|
||||
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||
Value* alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
|
||||
Value* alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
||||
|
||||
// Size.
|
||||
Value* int64Size = rewriter.create<LLVM::SExtOp>(
|
||||
Value *int64Size = rewriter.create<LLVM::SExtOp>(
|
||||
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
||||
|
||||
// Memcpy call
|
||||
rewriter.create<CallOp>(loc, memcpyRef,
|
||||
LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||
ArrayRef<Value*>(
|
||||
rewriter.create<CallOp>(
|
||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||
ArrayRef<Value *>(
|
||||
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
/// Return a symbol reference to the memcpy function, inserting it into the
|
||||
/// module if necessary.
|
||||
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter,
|
||||
ModuleOp module, LLVM::LLVMDialect* llvmDialect) {
|
||||
auto* context = module.getContext();
|
||||
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
|
||||
ModuleOp module,
|
||||
LLVM::LLVMDialect *llvmDialect) {
|
||||
auto *context = module.getContext();
|
||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||
// Create a function declaration for memcpy, the signature is:
|
||||
|
@ -87,19 +107,355 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
|
|||
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
|
||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
|
||||
llvmVoidTy,
|
||||
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
|
||||
false);
|
||||
|
||||
// Insert the memcpy function into the body of the parent module.
|
||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
rewriter.create<LLVM::LLVMFuncOp>(
|
||||
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
||||
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(),
|
||||
"llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||
}
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KRNL to LLVM: KrnlEntryPointOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
|
||||
public:
|
||||
using OpRewritePattern<KrnlEntryPointOp>::OpRewritePattern;
|
||||
|
||||
enum class API {
|
||||
CREATE_ORDERED_DYN_MEM_REF_DICT,
|
||||
CREATE_DYN_MEM_REF,
|
||||
GET_DYN_MEM_REF,
|
||||
SET_DYN_MEM_REF,
|
||||
GET_DATA,
|
||||
SET_DATA,
|
||||
GET_SIZES,
|
||||
GET_STRIDES,
|
||||
};
|
||||
|
||||
struct ApiSpec {
|
||||
API id;
|
||||
std::string name;
|
||||
FlatSymbolRefAttr symbolRef;
|
||||
LLVM::LLVMType outputTy;
|
||||
SmallVector<LLVM::LLVMType, 4> inputTys;
|
||||
|
||||
ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy,
|
||||
ArrayRef<LLVM::LLVMType> inputTys)
|
||||
: id(id), name(name), outputTy(outputTy),
|
||||
inputTys(inputTys.begin(), inputTys.end()) {}
|
||||
|
||||
LLVM::LLVMType funcTy() {
|
||||
return LLVM::LLVMType::getFunctionTy(outputTy, inputTys,
|
||||
/*isVarArg=*/false);
|
||||
}
|
||||
};
|
||||
|
||||
PatternMatchResult matchAndRewrite(KrnlEntryPointOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
auto *llvmDialect =
|
||||
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||
auto module = op.getParentOfType<ModuleOp>();
|
||||
auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect);
|
||||
auto loc = op.getLoc();
|
||||
auto numOutputs =
|
||||
op.getAttrOfType<IntegerAttr>(KrnlEntryPointOp::getNumOutputsAttrName())
|
||||
.getInt();
|
||||
|
||||
using LLVMType = LLVM::LLVMType;
|
||||
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
|
||||
|
||||
// Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic
|
||||
// signature. The signature is dynamic because it remains the same no matter
|
||||
// what the model input/output schema look like. Such dynamic signature
|
||||
// takes a opaque ptr as input, representing a ptr to a data structure
|
||||
// containing a set of dynamic memrefs wrapped in a vector; similarly the
|
||||
// output is also a opaque ptr to a data structure with output memrefs
|
||||
// wrapped within it.
|
||||
auto staticEntryPointFuncName =
|
||||
op.getAttrOfType<SymbolRefAttr>(
|
||||
KrnlEntryPointOp::getEntryPointFuncAttrName())
|
||||
.getLeafReference();
|
||||
auto dynEntryPointName = "_dyn_entry_point_" + staticEntryPointFuncName;
|
||||
assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr &&
|
||||
"dynamic entry point name is not unique");
|
||||
rewriter.eraseOp(op);
|
||||
auto dynEntryPointFuncTy =
|
||||
LLVMType::getFunctionTy(opaquePtrTy, {opaquePtrTy}, false);
|
||||
auto dynamicEntryPointFunc = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
loc, dynEntryPointName.str(), dynEntryPointFuncTy);
|
||||
auto &entryPointEntryBlock =
|
||||
createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc);
|
||||
rewriter.setInsertionPointToStart(&entryPointEntryBlock);
|
||||
|
||||
// Based on the static entry point type signature, unpack dynamic memory
|
||||
// refs to corresponding static memory refs.
|
||||
auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
|
||||
assert(staticEntryPointFunc &&
|
||||
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
|
||||
"entry point func must exist and be an llvm func op");
|
||||
auto staticEntryPointTy = dyn_cast<LLVM::LLVMFuncOp>(staticEntryPointFunc)
|
||||
.getType()
|
||||
.dyn_cast<LLVMType>();
|
||||
|
||||
// Retrieve dynamic mem refs from wrapped input, and convert every one of
|
||||
// them to static mem refs.
|
||||
SmallVector<Value *, 4> staticInputs;
|
||||
auto wrappedInput = entryPointEntryBlock.getArgument(0);
|
||||
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
|
||||
// Call API function to retrieve the i-th dynamic memref.
|
||||
auto idxVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(i));
|
||||
auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
|
||||
{wrappedInput, idxVal});
|
||||
|
||||
// Create a (static) memref type corresponding to the i-th memref input to
|
||||
// the inference function on stack, and load it to memRef.
|
||||
auto memRefPtrTy = staticEntryPointTy.getFunctionParamType(i);
|
||||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||
auto one = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(1));
|
||||
Value *ptrToMemRef =
|
||||
rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
|
||||
/*alignment=*/0);
|
||||
|
||||
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||
// from dynMemRef.
|
||||
fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc,
|
||||
apiRegistry, llvmDialect);
|
||||
|
||||
// ptrToMemRef will be an input to main computation graph function.
|
||||
staticInputs.emplace_back(ptrToMemRef);
|
||||
}
|
||||
|
||||
// If more than one output exists, the struct becomes a nested struct,
|
||||
// the unpacking logic can be more involved, so no support for now.
|
||||
assert(numOutputs == 1 && "only support 1 output tensor now.");
|
||||
|
||||
// Call static entry point with the memref ptrs created, and get output.
|
||||
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
|
||||
loc, staticEntryPointTy.getFunctionResultType(),
|
||||
rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
|
||||
|
||||
// Create wrapped output.
|
||||
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
|
||||
|
||||
// Get the first memref returned, convert to a dynamic memref and store
|
||||
// it in the wrapped Output.
|
||||
auto outMemRef = outputMemRefs.getResult(0);
|
||||
auto outMemRefTy = outMemRef->getType().dyn_cast<LLVMType>();
|
||||
auto outMemRefRank =
|
||||
outMemRefTy.getStructElementType(3).getArrayNumElements();
|
||||
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||
fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc,
|
||||
apiRegistry, llvmDialect);
|
||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(0));
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||
{wrappedOutput, zero, outDynMemRef});
|
||||
|
||||
// Return wrapped output.
|
||||
rewriter.create<LLVM::ReturnOp>(loc,
|
||||
SmallVector<Value *, 1>({wrappedOutput}));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
private:
|
||||
using ApiRegistry = std::map<API, ApiSpec>;
|
||||
|
||||
ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
using LLVMType = LLVM::LLVMType;
|
||||
auto voidTy = LLVMType::getVoidTy(llvmDialect);
|
||||
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
|
||||
auto int64Ty = LLVMType::getInt64Ty(llvmDialect);
|
||||
auto int64PtrTy = int64Ty.getPointerTo();
|
||||
|
||||
// Declare API type as an enum value, its string name and an LLVM Type
|
||||
// specifying its signature.
|
||||
// clang-format off
|
||||
std::vector<ApiSpec> apiSpecs = {
|
||||
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}),
|
||||
ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}),
|
||||
ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
|
||||
ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
|
||||
ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
|
||||
ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy})
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
// Declare APIs in the current module and build an API registry mapping api
|
||||
// identities to a symbol reference to the API function.
|
||||
ApiRegistry registry;
|
||||
for (auto &apiSpec : apiSpecs) {
|
||||
apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module,
|
||||
apiSpec.funcTy(), rewriter);
|
||||
registry.emplace(apiSpec.id, apiSpec);
|
||||
}
|
||||
|
||||
return registry;
|
||||
}
|
||||
|
||||
// Call a registered API, return the return SSA values if only one result is
|
||||
// returned, otherwise return nullptr.
|
||||
Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
||||
API apiId, ArrayRef<Value *> params) const {
|
||||
auto returnVals = rewriter.create<LLVM::CallOp>(
|
||||
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
|
||||
ArrayRef<Value *>(params));
|
||||
if (returnVals.getNumResults() == 1)
|
||||
return returnVals.getResult(0);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Helper function to insert an entry block to LLVM function.
|
||||
// (TODO): upstream this to MLIR.
|
||||
Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType,
|
||||
LLVM::LLVMFuncOp &dynamicEntryPointFunc) const {
|
||||
// Add entry block:
|
||||
auto *entryPointEntryBlock = new Block();
|
||||
dynamicEntryPointFunc.push_back(entryPointEntryBlock);
|
||||
llvm::SmallVector<Type, 4> argTypes;
|
||||
for (size_t i = 0; i < dynEntryPointFuncType.getFunctionNumParams(); i++)
|
||||
argTypes.emplace_back(dynEntryPointFuncType.getFunctionParamType(i));
|
||||
entryPointEntryBlock->addArguments(argTypes);
|
||||
return *entryPointEntryBlock;
|
||||
}
|
||||
|
||||
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
||||
PatternRewriter &rewriter,
|
||||
const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
||||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
|
||||
Value *memRef =
|
||||
rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef);
|
||||
|
||||
// Set dataPtr and alignedDataPtr;
|
||||
auto dataPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef});
|
||||
dataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, memRefTy.getStructElementType(0), dataPtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, dataPtr,
|
||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, dataPtr,
|
||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
|
||||
|
||||
// Use zero offset now.
|
||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int64Ty, rewriter.getI64IntegerAttr(0));
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, zero,
|
||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
|
||||
|
||||
// Get rank, sizes array ptr and strides array ptr.
|
||||
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef});
|
||||
auto stridesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef});
|
||||
|
||||
for (decltype(rank) i = 0; i < rank; i++) {
|
||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int64Ty, rewriter.getI64IntegerAttr(i));
|
||||
|
||||
// Insert size of the dimension.
|
||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value *>({dimIdx}));
|
||||
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
|
||||
dimSizePtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, dimSize,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
||||
|
||||
// Insert stride of the dimension.
|
||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value *>({dimIdx}));
|
||||
auto dimStride = rewriter.create<LLVM::LoadOp>(
|
||||
loc, int64Ty.getPointerTo(), dimStridePtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, dimStride,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||
}
|
||||
|
||||
rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef);
|
||||
}
|
||||
|
||||
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
|
||||
PatternRewriter &rewriter, const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
|
||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||
Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, outMemRefTy.getStructElementType(0), &outMemRef,
|
||||
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
||||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||
{&outDynMemRef, outMemRefDataPtr});
|
||||
|
||||
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef});
|
||||
auto stridesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef});
|
||||
|
||||
for (decltype(rank) i = 0; i < rank; i++) {
|
||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int64Ty, rewriter.getI64IntegerAttr(i));
|
||||
|
||||
// Transfer size of dimension from memref to dynamic memref.
|
||||
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, int64Ty, &outMemRef,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value *>({dimIdx}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
|
||||
|
||||
// Transfer stride of dimension from memref to dynamic memref.
|
||||
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, int64Ty, &outMemRef,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), stridesArrayPtr,
|
||||
ArrayRef<Value *>({dimIdx}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KRNL + Stadard + Affine dialects lowering to LLVM.
|
||||
|
@ -109,7 +465,7 @@ namespace {
|
|||
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
|
||||
void runOnModule() final;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
} // end anonymous namespace
|
||||
|
||||
void KrnlToLLVMLoweringPass::runOnModule() {
|
||||
// Define the target for this lowering i.e. the LLVM dialect.
|
||||
|
@ -128,12 +484,13 @@ void KrnlToLLVMLoweringPass::runOnModule() {
|
|||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
||||
patterns.insert<KrnlMemcpyOpLowering>(&getContext());
|
||||
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
|
||||
&getContext());
|
||||
|
||||
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
||||
// ensures that only legal operations will remain after the conversion.
|
||||
auto module = getModule();
|
||||
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
|
||||
if (failed(
|
||||
applyFullConversion(getModule(), target, patterns, &typeConverter)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -142,5 +499,5 @@ std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
|
|||
return std::make_unique<KrnlToLLVMLoweringPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<KrnlToLLVMLoweringPass> pass(
|
||||
"lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
|
||||
static PassRegistration<KrnlToLLVMLoweringPass>
|
||||
pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
|
||||
|
|
|
@ -7,10 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <tuple>
|
||||
|
||||
#include "llvm/Bitcode/BitcodeWriter.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
@ -24,9 +21,7 @@
|
|||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
#include "src/compiler/pass/passes.hpp"
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
add_library(cruntime
|
||||
dyn_memref.cpp
|
||||
dyn_memref.h
|
||||
data_type.h)
|
||||
target_include_directories(cruntime
|
||||
PRIVATE ${DLC_SRC_ROOT} ${DLC_BIN_ROOT}
|
||||
${DLC_SRC_ROOT})
|
||||
|
||||
pybind11_add_module(pyruntime
|
||||
dyn_memref.cpp
|
||||
dyn_memref.h
|
||||
runtime.cpp
|
||||
runtime.hpp)
|
||||
target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
|
||||
target_include_directories(pyruntime
|
||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||
${ONNF_SRC_ROOT})
|
|
@ -0,0 +1,30 @@
|
|||
enum DYN_MEMREF_DATA_TYPE {
|
||||
UNDEFINED = 0;
|
||||
// Basic types.
|
||||
FLOAT = 1; // float
|
||||
UINT8 = 2; // uint8_t
|
||||
INT8 = 3; // int8_t
|
||||
UINT16 = 4; // uint16_t
|
||||
INT16 = 5; // int16_t
|
||||
INT32 = 6; // int32_t
|
||||
INT64 = 7; // int64_t
|
||||
STRING = 8; // string
|
||||
BOOL = 9; // bool
|
||||
|
||||
// IEEE754 half-precision floating-point format (16 bits wide).
|
||||
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
|
||||
FLOAT16 = 10;
|
||||
|
||||
DOUBLE = 11;
|
||||
UINT32 = 12;
|
||||
UINT64 = 13;
|
||||
COMPLEX64 = 14; // complex with float32 real and imaginary components
|
||||
COMPLEX128 = 15; // complex with float64 real and imaginary components
|
||||
|
||||
// Non-IEEE floating-point format based on IEEE754 single-precision
|
||||
// floating-point number truncated to 16 bits.
|
||||
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
|
||||
BFLOAT16 = 16;
|
||||
|
||||
// Future extensions go here.
|
||||
};
|
|
@ -0,0 +1,74 @@
|
|||
#include <cassert>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dyn_memref.h"
|
||||
|
||||
DynMemRef::DynMemRef(int _rank) {
|
||||
rank = _rank;
|
||||
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
|
||||
strides = (int64_t *)malloc(rank * sizeof(int64_t));
|
||||
}
|
||||
|
||||
// An ordered dynamic MemRef dictionary.
|
||||
// The goal is to support accessing dynamic memory ref by name and by index.
|
||||
// Currently, only accessing by index is supported.
|
||||
struct OrderedDynMemRefDict {
|
||||
std::map<std::string, DynMemRef *> tensorDict;
|
||||
std::vector<std::string> orderedNames;
|
||||
};
|
||||
|
||||
int numDynMemRefs(OrderedDynMemRefDict *dict) {
|
||||
return dict->orderedNames.size();
|
||||
}
|
||||
|
||||
OrderedDynMemRefDict *createOrderedDynMemRefDict() {
|
||||
return new OrderedDynMemRefDict();
|
||||
}
|
||||
|
||||
DynMemRef *createDynMemRef(int rank) { return new DynMemRef(rank); }
|
||||
|
||||
DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
|
||||
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
|
||||
}
|
||||
|
||||
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
|
||||
DynMemRef *tensor) {
|
||||
if (tensorDict->orderedNames.capacity() <= idx)
|
||||
tensorDict->orderedNames.resize(idx + 1);
|
||||
|
||||
// The dynamic memref is essentially anonymous, since we are storing it by
|
||||
// indexed position.
|
||||
// TODO: can use random string as names to reduce chance of collision.
|
||||
auto unique_name = std::to_string(idx);
|
||||
assert(tensorDict->tensorDict.count(unique_name) == 0 &&
|
||||
"duplicate dynamic mem ref name");
|
||||
|
||||
tensorDict->orderedNames[idx] = unique_name;
|
||||
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
|
||||
}
|
||||
|
||||
void *getData(DynMemRef *dynMemRef) { return dynMemRef->data; }
|
||||
|
||||
void setData(DynMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
|
||||
|
||||
void *getAlignedData(DynMemRef *dynMemRef) { return dynMemRef->alignedData; }
|
||||
|
||||
void setAlignedData(DynMemRef *dynMemRef, void *dataPtr) {
|
||||
dynMemRef->alignedData = dataPtr;
|
||||
}
|
||||
|
||||
INDEX_TYPE *getSizes(DynMemRef *dynMemRef) { return dynMemRef->sizes; }
|
||||
|
||||
void setSizes(DynMemRef *dynMemRef, INDEX_TYPE *sizes) {
|
||||
for (int i = 0; i < dynMemRef->rank; i++)
|
||||
dynMemRef->sizes[i] = sizes[i];
|
||||
}
|
||||
|
||||
int64_t *getStrides(DynMemRef *dynMemRef) { return dynMemRef->strides; }
|
||||
|
||||
void setStrides(DynMemRef *dynMemRef, int64_t *strides) {
|
||||
for (int i = 0; i < dynMemRef->rank; i++)
|
||||
dynMemRef->sizes[i] = strides[i];
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
typedef int64_t INDEX_TYPE;
|
||||
|
||||
// This is a dynamic version of memref.
|
||||
// The same struct can be used to represent memrefs of
|
||||
// all ranks and type combinations.
|
||||
struct DynMemRef {
|
||||
void *data;
|
||||
void *alignedData;
|
||||
INDEX_TYPE offset;
|
||||
|
||||
unsigned int rank;
|
||||
INDEX_TYPE *sizes;
|
||||
int64_t *strides;
|
||||
|
||||
DynMemRef(int _rank);
|
||||
};
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Ordered DynMemRef Dictionary is a data structure for wrapping the input
|
||||
// dynmemrefs so that they can be addressed both by index and by name.
|
||||
struct OrderedDynMemRefDict;
|
||||
|
||||
// Get number of dynamic memrefs in OrderedDynMemRefDict dict.
|
||||
int numDynMemRefs(OrderedDynMemRefDict *dict);
|
||||
|
||||
// Create an ordered dynmemref dictionary.
|
||||
OrderedDynMemRefDict *createOrderedDynMemRefDict();
|
||||
|
||||
// Create a dynmemref with a certain rank.
|
||||
DynMemRef *createDynMemRef(int rank);
|
||||
|
||||
// Get the i-th dynmemref from orderedDict.
|
||||
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
|
||||
|
||||
// Set the i-th dynmemref in orderedDict to be dynMemRef.
|
||||
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
|
||||
DynMemRef *dynMemRef);
|
||||
|
||||
// Get data pointer from dynMemRef.
|
||||
void *getData(DynMemRef *dynMemRef);
|
||||
|
||||
// Set data pointer for dynMemRef.
|
||||
void setData(DynMemRef *dynMemRef, void *data);
|
||||
|
||||
// Get algined data pointer from dynMemRef.
|
||||
void *getAlignedData(DynMemRef *);
|
||||
|
||||
// Set aligned data pointer for dynMemRef.
|
||||
void setAlignedData(DynMemRef *, void *);
|
||||
|
||||
// Get ptr to sizes array.
|
||||
INDEX_TYPE *getSizes(DynMemRef *);
|
||||
|
||||
// Get ptr to strides array.
|
||||
int64_t *getStrides(DynMemRef *);
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
#include "runtime.hpp"
|
||||
|
||||
ExecutionSession::ExecutionSession(std::string sharedLibPath,
|
||||
std::string entryPointName) {
|
||||
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
|
||||
_entryPointFunc =
|
||||
(entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
|
||||
}
|
||||
|
||||
std::vector<py::array>
|
||||
ExecutionSession::run(std::vector<py::array> inputsPyArray) {
|
||||
assert(_entryPointFunc && "entry point not loaded");
|
||||
auto *wrappedInput = createOrderedDynMemRefDict();
|
||||
int inputIdx = 0;
|
||||
for (auto inputPyArray : inputsPyArray) {
|
||||
auto *inputDynMemRef = createDynMemRef(inputPyArray.ndim());
|
||||
assert(inputPyArray.flags() && py::array::c_style &&
|
||||
"expect contiguous python array");
|
||||
|
||||
if (inputPyArray.writeable()) {
|
||||
inputDynMemRef->data = inputPyArray.mutable_data();
|
||||
inputDynMemRef->alignedData = inputPyArray.mutable_data();
|
||||
} else {
|
||||
// If data is not writable, copy them to a writable buffer.
|
||||
auto *copiedData = (float *)malloc(inputPyArray.nbytes());
|
||||
memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
|
||||
inputDynMemRef->data = copiedData;
|
||||
inputDynMemRef->alignedData = copiedData;
|
||||
}
|
||||
|
||||
for (int i = 0; i < inputPyArray.ndim(); i++) {
|
||||
inputDynMemRef->sizes[i] = inputPyArray.shape(i);
|
||||
inputDynMemRef->strides[i] = inputPyArray.strides(i);
|
||||
}
|
||||
|
||||
setDynMemRef(wrappedInput, inputIdx++, inputDynMemRef);
|
||||
}
|
||||
|
||||
std::vector<py::array> outputPyArrays;
|
||||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
|
||||
auto *dynMemRef = getDynMemRef(wrappedOutput, i);
|
||||
auto shape = std::vector<int64_t>(dynMemRef->sizes,
|
||||
dynMemRef->sizes + dynMemRef->rank);
|
||||
outputPyArrays.emplace_back(
|
||||
py::array(py::dtype("float32"), shape, dynMemRef->data));
|
||||
}
|
||||
|
||||
return outputPyArrays;
|
||||
}
|
||||
|
||||
ExecutionSession::~ExecutionSession() { dlclose(_sharedLibraryHandle); }
|
|
@ -0,0 +1,37 @@
|
|||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "src/runtime/dyn_memref.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);
|
||||
|
||||
class ExecutionSession {
|
||||
public:
|
||||
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
|
||||
|
||||
std::vector<py::array> run(std::vector<py::array> inputsPyArray);
|
||||
|
||||
~ExecutionSession();
|
||||
|
||||
private:
|
||||
// Handler to the shared library file being loaded.
|
||||
void *_sharedLibraryHandle = nullptr;
|
||||
|
||||
// Entry point function.
|
||||
entryPointFuncType _entryPointFunc = nullptr;
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(pyruntime, m) {
|
||||
py::class_<ExecutionSession>(m, "ExecutionSession")
|
||||
.def(py::init<const std::string &, const std::string &>())
|
||||
.def("run", &ExecutionSession::run);
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import unittest
|
||||
import onnx.backend.base
|
||||
import onnx.backend.test
|
||||
|
||||
from onnx.backend.base import Device, DeviceType
|
||||
import onnx.shape_inference
|
||||
import onnx.version_converter
|
||||
from typing import Optional, Text, Any, Tuple, Sequence
|
||||
from onnx import NodeProto, ModelProto, TensorProto
|
||||
import numpy # type: ignore
|
||||
import subprocess
|
||||
from pyruntime import ExecutionSession
|
||||
|
||||
CXX = os.getenv('CXX')
|
||||
FE = os.getenv('FE')
|
||||
LLC = os.getenv('LLC')
|
||||
RT_DIR = os.getenv('RT_DIR')
|
||||
assert CXX and FE and LLC and RT_DIR, "tools path not set"
|
||||
|
||||
class DummyBackend(onnx.backend.base.Backend):
|
||||
@classmethod
|
||||
def prepare(
|
||||
cls,
|
||||
model,
|
||||
device='CPU',
|
||||
**kwargs
|
||||
):
|
||||
super(DummyBackend, cls).prepare(model, device, **kwargs)
|
||||
# Save model to disk as temp_model.onnx.
|
||||
onnx.save(model, "temp_model.onnx")
|
||||
# Call frontend to process temp_model.onnx, bit code will be generated.
|
||||
subprocess.run([FE, "temp_model.onnx"], stdout=subprocess.PIPE)
|
||||
# Call llc to generate object file from bitcode.
|
||||
subprocess.run([LLC, "-filetype=obj", "model.bc"],
|
||||
stdout=subprocess.PIPE)
|
||||
# Generate shared library from object file, linking with c runtime.
|
||||
subprocess.run([
|
||||
CXX, "-shared", "model.o", "-o", "model.so", "-L" + RT_DIR,
|
||||
"-lcruntime"
|
||||
],
|
||||
stdout=subprocess.PIPE)
|
||||
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
|
||||
|
||||
@classmethod
|
||||
def supports_device(cls, device):
|
||||
d = Device(device)
|
||||
if d.type == DeviceType.CPU:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__)
|
||||
|
||||
# Test directories:
|
||||
# https://github.com/onnx/onnx/tree/master/onnx/backend/test/data/node
|
||||
|
||||
test_to_enable = [
|
||||
# Add Op:
|
||||
"test_add_cpu",
|
||||
"test_add_bcast_cpu",
|
||||
|
||||
# And Op:
|
||||
|
||||
# Sub Op:
|
||||
"test_sub_cpu",
|
||||
"test_sub_bcast_cpu",
|
||||
"test_sub_example_cpu",
|
||||
|
||||
# Cosh Op:
|
||||
"test_cosh_cpu",
|
||||
"test_cosh_example_cpu",
|
||||
|
||||
# Div Op:
|
||||
"test_div_cpu",
|
||||
"test_div_bcast_cpu",
|
||||
"test_div_example_cpu",
|
||||
|
||||
# Elu Op:
|
||||
"test_elu_cpu",
|
||||
"test_elu_default_cpu",
|
||||
"test_elu_example_cpu",
|
||||
|
||||
# Exp Op:
|
||||
"test_exp_cpu",
|
||||
"test_exp_example_cpu",
|
||||
|
||||
# Hard Sigmoid Op:
|
||||
"test_hardsigmoid_cpu",
|
||||
"test_hardsigmoid_default_cpu",
|
||||
"test_hardsigmoid_example_cpu",
|
||||
|
||||
# Leaky Relu Op:
|
||||
"test_leakyrelu_cpu",
|
||||
"test_leakyrelu_default_cpu",
|
||||
"test_leakyrelu_example_cpu",
|
||||
|
||||
# Max Op:
|
||||
# "test_max_example_cpu", <- error
|
||||
"test_max_one_input_cpu",
|
||||
# "test_max_two_inputs_cpu", <- error
|
||||
|
||||
# Min Op:
|
||||
# "test_min_example_cpu", <- error
|
||||
"test_min_one_input_cpu",
|
||||
# "test_min_two_inputs_cpu", <- error
|
||||
|
||||
# Mul Op:
|
||||
"test_mul_cpu",
|
||||
"test_mul_bcast_cpu",
|
||||
"test_mul_example_cpu",
|
||||
|
||||
# Relu Op:
|
||||
"test_relu_cpu",
|
||||
|
||||
# Selu Op:
|
||||
"test_selu_cpu",
|
||||
"test_selu_default_cpu",
|
||||
"test_selu_example_cpu",
|
||||
|
||||
# Sigmoid Op:
|
||||
"test_sigmoid_cpu",
|
||||
"test_sigmoid_example_cpu",
|
||||
|
||||
# Sum Op:
|
||||
#"test_sum_example_cpu", <- error
|
||||
"test_sum_one_input_cpu",
|
||||
#"test_sum_two_inputs_cpu", <- error
|
||||
|
||||
# Reciprocal Op:
|
||||
#"test_reciprocal_cpu", <- error on shape inference.
|
||||
#"test_reciprocal_example_cpu", <- error on shape inference.
|
||||
]
|
||||
|
||||
# Extract name of all test cases.
|
||||
import inspect
|
||||
all_tests = inspect.getmembers(
|
||||
backend_test.test_cases["OnnxBackendNodeModelTest"])
|
||||
all_test_names = list(map(lambda x: x[0], all_tests))
|
||||
for test_name in test_to_enable:
|
||||
assert test_name in all_test_names, "test name {} not found".format(test_name)
|
||||
backend_test.include(r"^{}$".format(test_name))
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
print()
|
||||
print("*" * 40)
|
||||
print("A total of {} tests should have run".format(len(test_to_enable)))
|
||||
print("*" * 40)
|
||||
|
||||
|
||||
# import all test cases at global scope to make them visible to python.unittest
|
||||
globals().update(backend_test.test_cases)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue