Support model debugging (#138)
* Call llc, ld from within onnx-mlir.
* Rename EmitLLVMBC -> EmitLib., reorder header files.
* Checkpoint, debug.py works.
* Automatically generate inputs in debug.py.
* Use float.
* Fix merge conflict, remove RapidCheck from this patch.
* Remove submodule rapidcheck properly.
* Reformat code.
* More comments.
* Add documentation.
* Add documentation to navigation.
* Account for the fact that some initializers may also appear as input.
This commit is contained in:
parent be62d85a32
commit a270af5ce0

@@ -0,0 +1,26 @@
# Debugging Numerical Error

Use the `util/debug.py` Python script to debug numerical errors when an onnx-mlir-compiled
inference executable produces numerical results that are inconsistent with those produced
by the training framework. The script runs the model through onnx-mlir and a reference
backend, and compares the intermediate results produced by the two backends layer by layer.
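The core of this strategy can be sketched in a few lines of Python: every intermediate
tensor is promoted to a graph output, so both backends return it and the pairs of tensors
can be checked one by one. The helper name below is illustrative; the complete
implementation ships as `util/debug.py` in this commit.

```python
import onnx

def expose_intermediate_outputs(model):
    # Promote every node output to a graph output so that both backends
    # return every intermediate tensor alongside the final results.
    names = list(dict.fromkeys(
        out for node in model.graph.node for out in node.output))
    del model.graph.output[:]  # same effect as the while/pop loop in debug.py
    for name in names:
        model.graph.output.extend([
            onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, None)])
    return model, names

# After compiling the modified model with onnx-mlir and running it through the
# reference backend, each pair of intermediate tensors is compared, e.g.:
#   np.testing.assert_array_almost_equal(ref_outs[i], outs[i], decimal=5)
```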
## Prerequisites

- Set the `ONNX_MLIR_HOME` environment variable to the path of the HOME directory for
  onnx-mlir. The HOME directory for onnx-mlir refers to the parent folder containing the
  `bin`, `lib`, etc. sub-folders in which ONNX-MLIR executables and libraries can be found.
- Install an ONNX backend; by default, onnxruntime is used as the testing backend. Install
  it by running `pip install onnxruntime`. To use a different testing backend, simply
  replace the code importing onnxruntime with code importing another ONNX-compliant
  backend, as in the sketch after this list.
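A minimal sketch of the expected backend wiring, mirroring `util/debug.py` (the usage
lines are placeholders, not part of the script):

```python
import onnxruntime

# Default reference backend used by util/debug.py; to swap backends, replace
# these lines with an equivalent `prepare` whose result exposes
# run(output_names, input_feed).
prepare = onnxruntime.InferenceSession

# Illustrative usage (path and inputs are placeholders):
#   ref_session = prepare("/path/to/model.onnx")
#   ref_outs = ref_session.run(output_names, {"input_0": some_numpy_array})
```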
## Usage

`util/debug.py` supports the following command-line options:

```bash
usage: debug.py [-h] model_path

positional arguments:
  model_path  Path to the model to debug.
```
```diff
@@ -29,5 +29,7 @@ toc:
       # url: /piece1.html
   - title: Tools
     subfolderitems:
+      - page: debug.py - Debug Numerical Errors
+        url: /DebuggingNumericalError.html
       - page: DocCheck - Handling Necessary Code Duplication
         url: /doc_check/
```
```diff
@@ -10,12 +10,13 @@
 
 #include <cstdio>
 #include <fcntl.h>
-#include "llvm/Support/Program.h"
+#include <llvm/Support/Program.h>
 
 #include "src/ExternalUtil.hpp"
 #include "src/MainUtils.hpp"
 
+#include "MainUtils.hpp"
 
 #ifdef _WIN32
 #include <io.h>
 #else
```
```diff
@@ -387,36 +387,57 @@ public:
       staticInputs.emplace_back(ptrToMemRef);
     }
 
-    // If more than one output exists, the struct becomes a nested struct,
-    // the unpacking logic can be more involved, so no support for now.
-    assert(numOutputs == 1 && "only support 1 output tensor now.");
-
     // Call static entry point with the memref ptrs created, and get output.
-    auto outputMemRefs = rewriter.create<LLVM::CallOp>(loc,
-        staticEntryPointTy.getFunctionResultType(),
-        rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
-        staticInputs);
+    auto outMemRefs =
+        rewriter
+            .create<LLVM::CallOp>(loc,
+                staticEntryPointTy.getFunctionResultType(),
+                rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
+                staticInputs)
+            .getResult(0);
+    auto outMemRefsType = outMemRefs.getType().dyn_cast<LLVMType>();
+
+    std::vector<mlir::Value> outMemRefList;
+    if (numOutputs == 1) {
+      // If only one output tensor exists, the tensor's corresponding memref
+      // descriptor will be returned as is.
+      outMemRefList.emplace_back(outMemRefs);
+    } else {
+      // Otherwise, if multiple tensors are to be returned, the returned value
+      // is a struct. Multiple tensors' memref descriptors are packed into the
+      // same struct. So we unpack them iteratively to outMemRefList.
+      for (int i = 0; i < numOutputs; i++) {
+        auto position = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(i)});
+        auto type = outMemRefsType.getStructElementType(i);
+        auto extractOp = rewriter.create<LLVM::ExtractValueOp>(loc,
+            /*res=*/type,
+            /*type=*/outMemRefs,
+            /*position=*/position);
+        outMemRefList.emplace_back(extractOp.getResult());
+      }
+    }
 
     // Create wrapped output.
     auto wrappedOutput = callApi(
         rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
 
-    // Get the first memref returned, convert to a dynamic memref and store
-    // it in the wrapped Output.
-    auto outMemRef = outputMemRefs.getResult(0);
-    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
-    auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
-    auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
-    auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
-        API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
-    fillDynMemRefWithMemRef(
-        outMemRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
-    auto zero = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Ty, rewriter.getI32IntegerAttr(0));
-    callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
-        {wrappedOutput, zero, outDynMemRef});
+    for (decltype(numOutputs) i = 0; i < outMemRefList.size(); i++) {
+      // Get the i-th memref returned, convert to a dynamic memref and store it
+      // in the wrappedOutput.
+      auto memRef = outMemRefList.at(i);
+      auto outMemRefTy = memRef.getType().dyn_cast<LLVMType>();
+      auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
+      auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
+          loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
+      auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
+          API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
+      fillDynMemRefWithMemRef(
+          memRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
+      auto idx = rewriter.create<LLVM::ConstantOp>(
+          loc, int32Ty, rewriter.getI32IntegerAttr(i));
+      callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
+          {wrappedOutput, idx, outDynMemRef});
+    }
     // Return wrapped output.
     rewriter.create<LLVM::ReturnOp>(
         loc, SmallVector<Value, 1>({wrappedOutput}));
```
```diff
@@ -613,7 +634,6 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-  // target.addLegalOp<KrnlEntryPointOp>();
 
   // Lower the MemRef types to a representation in LLVM.
   LLVMTypeConverter typeConverter(&getContext());
```
```diff
@@ -626,7 +646,6 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
   populateStdToLLVMConversionPatterns(typeConverter, patterns,
       /*emitCWrapperS=*/true,
       /*useAlignedAlloc=*/false);
 
   patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
-
   // Lower from the `krnl` dialect i.e. the Reshape operation.
```
debug.py (new file)
@@ -0,0 +1,112 @@

```python
import os
import sys
import argparse
import onnx
import subprocess
import numpy as np
import tempfile

from collections import OrderedDict

# Reference backend, use onnxruntime by default
import onnxruntime
prepare = onnxruntime.InferenceSession

if (not os.environ.get('ONNX_MLIR_HOME', None)):
    raise RuntimeError(
        "Environment variable ONNX_MLIR_HOME is not set, please set it to the path to "
        "the HOME directory for onnx-mlir. The HOME directory for onnx-mlir refers to "
        "the parent folder containing the bin, lib, etc sub-folders in which ONNX-MLIR "
        "executables and libraries can be found.")

VERBOSE = os.environ.get('VERBOSE', False)
ONNX_MLIR = os.path.join(os.environ['ONNX_MLIR_HOME'], "bin/onnx-mlir")

# Include runtime directory in python paths, so pyruntime can be imported.
RUNTIME_DIR = os.path.join(os.environ['ONNX_MLIR_HOME'], "lib")
sys.path.append(RUNTIME_DIR)

try:
    from pyruntime import ExecutionSession
except ImportError:
    raise ImportError(
        "Looks like you did not build the pyruntime target, build it by running `make pyruntime`."
    )


def execute_commands(cmds):
    if (VERBOSE):
        print(" ".join(cmds))
    subprocess.run(cmds, stdout=subprocess.PIPE, check=True)


def extend_model_output(model, intermediate_outputs):
    # onnx-mlir doesn't care about manually specified output types & shapes.
    DUMMY_TENSOR_TYPE = onnx.TensorProto.FLOAT

    while (len(model.graph.output)):
        model.graph.output.pop()

    for output_name in intermediate_outputs:
        output_value_info = onnx.helper.make_tensor_value_info(
            output_name, DUMMY_TENSOR_TYPE, None)
        model.graph.output.extend([output_value_info])
    return model


def main(model_path):
    model = onnx.load(model_path)
    intermediate_outputs = sum(
        [list(node.output) for node in model.graph.node], [])
    intermediate_outputs = list(OrderedDict.fromkeys(intermediate_outputs))
    model = extend_model_output(model, intermediate_outputs)

    with tempfile.TemporaryDirectory() as temp_dir:
        print("Temporary directory has been created at {}".format(temp_dir))

        # Save modified model & invoke onnx-mlir to compile it.
        temp_model_path = os.path.join(temp_dir, "model.onnx")
        onnx.save(model, temp_model_path)
        execute_commands([ONNX_MLIR, temp_model_path])

        # Use the generated shared library to create an execution session.
        temp_shared_lib_path = os.path.join(temp_dir, "model.so")
        sess = ExecutionSession(temp_shared_lib_path,
                                "_dyn_entry_point_main_graph")

        # Generate random data as input.
        inputs = []
        input_names = []
        initializers = list(map(lambda x: x.name, model.graph.initializer))
        np.random.seed(42)
        for input_proto in model.graph.input:
            if input_proto.name not in initializers:
                input_names.append(input_proto.name)
                shape_proto = input_proto.type.tensor_type.shape
                explicit_shape = []
                for dim in shape_proto.dim:
                    assert dim.dim_value, "Can only debug models with inputs that have explicit shapes."
                    explicit_shape.append(dim.dim_value)
                inputs.append(
                    np.random.uniform(-1.0, 1.0, explicit_shape).astype(np.float32))

        # Run the compiled inference function on the randomly generated data.
        outs = sess.run(inputs)

        # Run the model with reference backend and get results.
        ref_session = prepare(temp_model_path)
        output_names = list(map(lambda x: x.name, model.graph.output))
        input_feed = dict(zip(input_names, inputs))
        ref_outs = ref_session.run(output_names, input_feed)

        # For each intermediate output tensor, compare results.
        for i, name in enumerate(intermediate_outputs):
            print("Verifying value of {}".format(name))
            np.testing.assert_array_almost_equal(ref_outs[i], outs[i], decimal=5)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', type=str, help="Path to the model to debug.")
    args = parser.parse_args()
    main(**vars(args))
```