Support model debugging (#138)

* Call llc, ld from within onnx-mlir.

* Rename EmitLLVMBC -> EmitLib., reorder header files

* Checkpoint, debug.py works.

* Automatically generate inputs in debug.py.

* Use float.

* Fix merge conflict, remove RapidCheck from this patch.

* Remove submodule rapidcheck properly.

* Reformat code.

* More comments.

* Add documentation.

* Add documentation to navigation.

* Account for the fact that some initializers may also appear as input.
Tian Jin 2020-05-21 13:02:48 +08:00 committed by GitHub
parent be62d85a32
commit a270af5ce0
5 changed files with 188 additions and 28 deletions


@@ -0,0 +1,26 @@
# Debugging Numerical Error
Use the `utils/debug.py` Python script to debug numerical errors when the onnx-mlir-compiled
inference executable produces numerical results that are inconsistent with those produced by
the training framework. The script runs the model through onnx-mlir and a reference backend,
and compares the intermediate results produced by the two backends layer by layer.
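
At its core, the check is a per-tensor comparison of onnx-mlir's results against the reference backend's results. The sketch below condenses that final step of `utils/debug.py` (model compilation, input generation, and session setup are omitted); the 5-decimal tolerance mirrors the one used in the script.

```python
import numpy as np

def compare_layer_by_layer(intermediate_names, onnx_mlir_outs, reference_outs):
    # Both result lists follow the order of intermediate_names, so tensors can
    # be compared positionally, one intermediate output at a time.
    for i, name in enumerate(intermediate_names):
        print("Verifying value of {}".format(name))
        np.testing.assert_array_almost_equal(
            reference_outs[i], onnx_mlir_outs[i], decimal=5)
```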
## Prerequisites
- Set the `ONNX_MLIR_HOME` environment variable to the path of the HOME directory for
onnx-mlir, i.e., the parent folder containing the `bin`, `lib`, etc. sub-folders in which
ONNX-MLIR executables and libraries can be found.
- Install an ONNX backend; by default, onnxruntime is used as the testing backend. Install it
by running `pip install onnxruntime`. To use a different testing backend, replace the code that
imports onnxruntime with code importing another ONNX-compliant backend (a minimal adapter
sketch follows this list).
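
For example, a backend that follows the `onnx.backend` API can be wrapped in a small adapter exposing the `run(output_names, input_feed)` interface that `debug.py` expects from `onnxruntime.InferenceSession`. This is only a minimal sketch; `some_other_backend` is a placeholder for whichever ONNX-compliant backend you actually install.

```python
import onnx
import some_other_backend  # placeholder for an installed ONNX-compliant backend

class ReferenceSession:
    # Adapter mimicking the onnxruntime.InferenceSession interface debug.py uses.
    def __init__(self, model_path):
        self.rep = some_other_backend.prepare(onnx.load(model_path))

    def run(self, output_names, input_feed):
        # onnx.backend-style backends take a list of inputs and return outputs
        # in graph-output order, which is the order debug.py compares against.
        return self.rep.run(list(input_feed.values()))

# In debug.py, replace `prepare = onnxruntime.InferenceSession` with:
prepare = ReferenceSession
```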
## Usage
`utils/debug.py` supports the following command-line options:
```bash
usage: debug.py [-h] model_path

positional arguments:
  model_path  Path to the model to debug.
```
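
For example (the build path and model file below are placeholders), the script can be driven programmatically as well as from the shell:

```python
import os
import subprocess

# ONNX_MLIR_HOME must point at the parent of the bin/ and lib/ folders.
env = dict(os.environ, ONNX_MLIR_HOME="/path/to/onnx-mlir/build")
subprocess.run(["python", "utils/debug.py", "model.onnx"], env=env, check=True)
```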


@@ -29,5 +29,7 @@ toc:
 #       url: /piece1.html
   - title: Tools
     subfolderitems:
+      - page: debug.py - Debug Numerical Errors
+        url: /DebuggingNumericalError.html
       - page: DocCheck - Handling Necessary Code Duplication
         url: /doc_check/


@@ -10,12 +10,13 @@
 #include <cstdio>
 #include <fcntl.h>
-#include <llvm/Support/Program.h>
+#include "llvm/Support/Program.h"
 #include "src/ExternalUtil.hpp"
 #include "src/MainUtils.hpp"
-#include "MainUtils.hpp"
 #ifdef _WIN32
 #include <io.h>
 #else


@@ -387,36 +387,57 @@ public:
       staticInputs.emplace_back(ptrToMemRef);
     }
-    // If more than one output exists, the struct becomes a nested struct,
-    // the unpacking logic can be more involved, so no support for now.
-    assert(numOutputs == 1 && "only support 1 output tensor now.");
     // Call static entry point with the memref ptrs created, and get output.
-    auto outputMemRefs = rewriter.create<LLVM::CallOp>(loc,
-        staticEntryPointTy.getFunctionResultType(),
-        rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
-        staticInputs);
+    auto outMemRefs =
+        rewriter
+            .create<LLVM::CallOp>(loc,
+                staticEntryPointTy.getFunctionResultType(),
+                rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
+                staticInputs)
+            .getResult(0);
+    auto outMemRefsType = outMemRefs.getType().dyn_cast<LLVMType>();
+    std::vector<mlir::Value> outMemRefList;
+    if (numOutputs == 1) {
+      // If only one output tensor exists, the tensor's corresponding memref
+      // descriptor will be returned as is.
+      outMemRefList.emplace_back(outMemRefs);
+    } else {
+      // Otherwise, if multiple tensors are to be returned, the returned value
+      // is a struct. Multiple tensors' memref descriptors are packed into the
+      // same struct. So we unpack them iteratively to outMemRefList.
+      for (int i = 0; i < numOutputs; i++) {
+        auto position = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(i)});
+        auto type = outMemRefsType.getStructElementType(i);
+        auto extractOp = rewriter.create<LLVM::ExtractValueOp>(loc,
+            /*res=*/type,
+            /*type=*/outMemRefs,
+            /*position=*/position);
+        outMemRefList.emplace_back(extractOp.getResult());
+      }
+    }
     // Create wrapped output.
     auto wrappedOutput = callApi(
         rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
-    // Get the first memref returned, convert to a dynamic memref and store
-    // it in the wrapped Output.
-    auto outMemRef = outputMemRefs.getResult(0);
-    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
-    auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
-    auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
-    auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
-        API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
-    fillDynMemRefWithMemRef(
-        outMemRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
-    auto zero = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Ty, rewriter.getI32IntegerAttr(0));
-    callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
-        {wrappedOutput, zero, outDynMemRef});
+    for (decltype(numOutputs) i = 0; i < outMemRefList.size(); i++) {
+      // Get the i-th memref returned, convert to a dynamic memref and store it
+      // in the wrappedOutput.
+      auto memRef = outMemRefList.at(i);
+      auto outMemRefTy = memRef.getType().dyn_cast<LLVMType>();
+      auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
+      auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
+          loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
+      auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
+          API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
+      fillDynMemRefWithMemRef(
+          memRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
+      auto idx = rewriter.create<LLVM::ConstantOp>(
+          loc, int32Ty, rewriter.getI32IntegerAttr(i));
+      callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
+          {wrappedOutput, idx, outDynMemRef});
+    }
     // Return wrapped output.
     rewriter.create<LLVM::ReturnOp>(
         loc, SmallVector<Value, 1>({wrappedOutput}));
@@ -613,7 +634,6 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-  // target.addLegalOp<KrnlEntryPointOp>();
   // Lower the MemRef types to a representation in LLVM.
   LLVMTypeConverter typeConverter(&getContext());
@@ -626,7 +646,6 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
   populateStdToLLVMConversionPatterns(typeConverter, patterns,
       /*emitCWrapperS=*/true,
       /*useAlignedAlloc=*/false);
   patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
   // Lower from the `krnl` dialect i.e. the Reshape operation.
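
The lowering change above means a compiled model with several result tensors now hands back one memref per output through the runtime API. Below is a minimal sketch of exercising that from Python, using the same `ExecutionSession` entry point as `utils/debug.py` further down (the library name, input shape, and output count are placeholders):

```python
import numpy as np
from pyruntime import ExecutionSession

# model.so is a placeholder for a library produced by `onnx-mlir model.onnx`.
sess = ExecutionSession("model.so", "_dyn_entry_point_main_graph")
outs = sess.run([np.random.rand(1, 3, 224, 224).astype(np.float32)])
print("Number of output tensors returned:", len(outs))
```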

utils/debug.py (new file)

@@ -0,0 +1,112 @@
import os
import sys
import argparse
import onnx
import subprocess
import numpy as np
import tempfile

from collections import OrderedDict

# Reference backend, use onnxruntime by default
import onnxruntime
prepare = onnxruntime.InferenceSession

if (not os.environ.get('ONNX_MLIR_HOME', None)):
    raise RuntimeError(
        "Environment variable ONNX_MLIR_HOME is not set, please set it to the path to "
        "the HOME directory for onnx-mlir. The HOME directory for onnx-mlir refers to "
        "the parent folder containing the bin, lib, etc sub-folders in which ONNX-MLIR "
        "executables and libraries can be found.")

VERBOSE = os.environ.get('VERBOSE', False)
ONNX_MLIR = os.path.join(os.environ['ONNX_MLIR_HOME'], "bin/onnx-mlir")

# Include runtime directory in python paths, so pyruntime can be imported.
RUNTIME_DIR = os.path.join(os.environ['ONNX_MLIR_HOME'], "lib")
sys.path.append(RUNTIME_DIR)

try:
    from pyruntime import ExecutionSession
except ImportError:
    raise ImportError(
        "Looks like you did not build the pyruntime target, build it by running `make pyruntime`."
    )


def execute_commands(cmds):
    if (VERBOSE):
        print(" ".join(cmds))
    subprocess.run(cmds, stdout=subprocess.PIPE, check=True)


def extend_model_output(model, intermediate_outputs):
    # onnx-mlir doesn't care about manually specified output types & shapes.
    DUMMY_TENSOR_TYPE = onnx.TensorProto.FLOAT

    while (len(model.graph.output)):
        model.graph.output.pop()

    for output_name in intermediate_outputs:
        output_value_info = onnx.helper.make_tensor_value_info(
            output_name, DUMMY_TENSOR_TYPE, None)
        model.graph.output.extend([output_value_info])
    return model


def main(model_path):
    model = onnx.load(model_path)
    intermediate_outputs = sum(
        [list(node.output) for node in model.graph.node], [])
    intermediate_outputs = list(OrderedDict.fromkeys(intermediate_outputs))
    model = extend_model_output(model, intermediate_outputs)

    with tempfile.TemporaryDirectory() as temp_dir:
        print("Temporary directory has been created at {}".format(temp_dir))

        # Save modified model & invoke onnx-mlir to compile it.
        temp_model_path = os.path.join(temp_dir, "model.onnx")
        onnx.save(model, temp_model_path)
        execute_commands([ONNX_MLIR, temp_model_path])

        # Use the generated shared library to create an execution session.
        temp_shared_lib_path = os.path.join(temp_dir, "model.so")
        sess = ExecutionSession(temp_shared_lib_path,
                                "_dyn_entry_point_main_graph")

        # Generate random data as input.
        inputs = []
        input_names = []
        initializers = list(map(lambda x: x.name, model.graph.initializer))
        np.random.seed(42)
        for input_proto in model.graph.input:
            if input_proto.name not in initializers:
                input_names.append(input_proto.name)
                shape_proto = input_proto.type.tensor_type.shape
                explicit_shape = []
                for dim in shape_proto.dim:
                    assert dim.dim_value, "Can only debug models with inputs that have explicit shapes."
                    explicit_shape.append(dim.dim_value)
                inputs.append(
                    np.random.uniform(-1.0, 1.0, explicit_shape).astype(np.float32))

        # Run the compiled inference function on the randomly generated data.
        outs = sess.run(inputs)

        # Run the model with reference backend and get results.
        ref_session = prepare(temp_model_path)
        output_names = list(map(lambda x: x.name, model.graph.output))
        input_feed = dict(zip(input_names, inputs))
        ref_outs = ref_session.run(output_names, input_feed)

        # For each intermediate output tensor, compare results.
        for i, name in enumerate(intermediate_outputs):
            print("Verifying value of {}".format(name))
            np.testing.assert_array_almost_equal(ref_outs[i], outs[i], decimal=5)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', type=str, help="Path to the model to debug.")
    args = parser.parse_args()
    main(**vars(args))