diff --git a/docs/DebuggingNumericalError.md b/docs/DebuggingNumericalError.md
new file mode 100644
index 0000000..5d7a78d
--- /dev/null
+++ b/docs/DebuggingNumericalError.md
@@ -0,0 +1,26 @@
+# Debugging Numerical Error
+
+Use the `utils/debug.py` Python script to debug numerical errors when the onnx-mlir-compiled
+inference executable produces numerical results that are inconsistent with those produced by the training framework.
+The script runs the model through onnx-mlir and a reference backend, and compares
+the intermediate results produced by the two backends layer by layer.
+
+## Prerequisites
+- Set the `ONNX_MLIR_HOME` environment variable to the path to
+  the HOME directory for onnx-mlir. The HOME directory for onnx-mlir refers to
+  the parent folder containing the `bin`, `lib`, etc. sub-folders in which ONNX-MLIR
+  executables and libraries can be found.
+- Install an ONNX backend; by default, onnxruntime is used as the testing backend. Install it by
+  running `pip install onnxruntime`. To use a different testing backend, simply replace the code
+  importing onnxruntime with code importing another ONNX-compliant backend.
+
+## Usage
+
+`utils/debug.py` supports the following command-line options:
+
+```bash
+usage: debug.py [-h] model_path
+
+positional arguments:
+  model_path  Path to the model to debug.
+```
diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml
index 07986a6..222d8ef 100644
--- a/docs/_data/navigation.yml
+++ b/docs/_data/navigation.yml
@@ -29,5 +29,7 @@ toc:
 #      url: /piece1.html
   - title: Tools
     subfolderitems:
+      - page: debug.py - Debug Numerical Errors
+        url: /DebuggingNumericalError.html
       - page: DocCheck - Handling Necessary Code Duplication
         url: /doc_check/
\ No newline at end of file
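
The documentation above describes the approach only in prose. The core idea is that every intermediate tensor is promoted to a graph output before the model is compiled, so that both backends return values that can be compared layer by layer. A condensed sketch of that idea follows; the helper names `expose_intermediate_outputs` and `compare_layer_by_layer` are introduced here for illustration only, while the actual implementation is `utils/debug.py`, added later in this patch.

```python
import onnx
import numpy as np


def expose_intermediate_outputs(model):
    """Promote every intermediate tensor of an onnx.ModelProto to a graph output."""
    names = []
    for node in model.graph.node:
        for name in node.output:
            if name and name not in names:
                names.append(name)
    # onnx-mlir does not need explicit output shapes/types, so a FLOAT
    # value_info with an unknown shape is good enough for debugging.
    del model.graph.output[:]
    model.graph.output.extend(
        onnx.helper.make_tensor_value_info(n, onnx.TensorProto.FLOAT, None)
        for n in names)
    return model, names


def compare_layer_by_layer(names, ref_outs, onnx_mlir_outs, decimal=5):
    # Both result lists follow the graph-output order established above.
    for name, ref, got in zip(names, ref_outs, onnx_mlir_outs):
        print("Verifying value of {}".format(name))
        np.testing.assert_array_almost_equal(ref, got, decimal=decimal)
```
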
diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp
index 68b46af..0e59994 100644
--- a/src/MainUtils.cpp
+++ b/src/MainUtils.cpp
@@ -10,12 +10,13 @@
 #include 
 #include 
-
-#include "llvm/Support/Program.h"
+#include 
 
 #include "src/ExternalUtil.hpp"
 #include "src/MainUtils.hpp"
 
+#include "MainUtils.hpp"
+
 #ifdef _WIN32
 #include 
 #else
diff --git a/src/Transform/LowerToLLVM.cpp b/src/Transform/LowerToLLVM.cpp
index b5be56c..b03177d 100644
--- a/src/Transform/LowerToLLVM.cpp
+++ b/src/Transform/LowerToLLVM.cpp
@@ -387,36 +387,57 @@ public:
       staticInputs.emplace_back(ptrToMemRef);
     }
 
-    // If more than one output exists, the struct becomes a nested struct,
-    // the unpacking logic can be more involved, so no support for now.
-    assert(numOutputs == 1 && "only support 1 output tensor now.");
-
     // Call static entry point with the memref ptrs created, and get output.
-    auto outputMemRefs = rewriter.create(loc,
-        staticEntryPointTy.getFunctionResultType(),
-        rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
-        staticInputs);
+    auto outMemRefs =
+        rewriter
+            .create(loc,
+                staticEntryPointTy.getFunctionResultType(),
+                rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
+                staticInputs)
+            .getResult(0);
+    auto outMemRefsType = outMemRefs.getType().dyn_cast();
+
+    std::vector outMemRefList;
+    if (numOutputs == 1) {
+      // If only one output tensor exists, the tensor's corresponding memref
+      // descriptor will be returned as is.
+      outMemRefList.emplace_back(outMemRefs);
+    } else {
+      // Otherwise, if multiple tensors are to be returned, the returned value
+      // is a struct. Multiple tensors' memref descriptors are packed into the
+      // same struct. So we unpack them iteratively to outMemRefList.
+      for (int i = 0; i < numOutputs; i++) {
+        auto position = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(i)});
+        auto type = outMemRefsType.getStructElementType(i);
+        auto extractOp = rewriter.create(loc,
+            /*res=*/type,
+            /*type=*/outMemRefs,
+            /*position=*/position);
+        outMemRefList.emplace_back(extractOp.getResult());
+      }
+    }
 
     // Create wrapped output.
     auto wrappedOutput = callApi(
         rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
 
-    // Get the first memref returned, convert to a dynamic memref and store
-    // it in the wrapped Output.
-    auto outMemRef = outputMemRefs.getResult(0);
-    auto outMemRefTy = outMemRef.getType().dyn_cast();
-    auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
-    auto outMemRefRankVal = rewriter.create(
-        loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
-    auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
-        API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
-    fillDynMemRefWithMemRef(
-        outMemRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
-    auto zero = rewriter.create(
-        loc, int32Ty, rewriter.getI32IntegerAttr(0));
-    callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
-        {wrappedOutput, zero, outDynMemRef});
-
+    for (decltype(numOutputs) i = 0; i < outMemRefList.size(); i++) {
+      // Get the i-th memref returned, convert to a dynamic memref and store it
+      // in the wrappedOutput.
+      auto memRef = outMemRefList.at(i);
+      auto outMemRefTy = memRef.getType().dyn_cast();
+      auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
+      auto outMemRefRankVal = rewriter.create(
+          loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
+      auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
+          API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
+      fillDynMemRefWithMemRef(
+          memRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
+      auto idx = rewriter.create(
+          loc, int32Ty, rewriter.getI32IntegerAttr(i));
+      callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
+          {wrappedOutput, idx, outDynMemRef});
+    }
     // Return wrapped output.
     rewriter.create(
         loc, SmallVector({wrappedOutput}));
@@ -613,7 +634,6 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect();
   target.addLegalOp();
-  // target.addLegalOp();
 
   // Lower the MemRef types to a representation in LLVM.
   LLVMTypeConverter typeConverter(&getContext());
@@ -626,7 +646,6 @@
   populateStdToLLVMConversionPatterns(typeConverter, patterns,
       /*emitCWrappers=*/true,
      /*useAlignedAlloc=*/false);
-
   patterns.insert(&getContext(), typeConverter);
 
   // Lower from the `krnl` dialect i.e. the Reshape operation.
diff --git a/utils/debug.py b/utils/debug.py
new file mode 100644
index 0000000..456c3f7
--- /dev/null
+++ b/utils/debug.py
@@ -0,0 +1,112 @@
+import os
+import sys
+import argparse
+import onnx
+import subprocess
+import numpy as np
+import tempfile
+
+from collections import OrderedDict
+
+# Reference backend, use onnxruntime by default
+import onnxruntime
+prepare = onnxruntime.InferenceSession
+
+if (not os.environ.get('ONNX_MLIR_HOME', None)):
+    raise RuntimeError(
+        "Environment variable ONNX_MLIR_HOME is not set, please set it to the path to "
+        "the HOME directory for onnx-mlir. The HOME directory for onnx-mlir refers to "
+        "the parent folder containing the bin, lib, etc sub-folders in which ONNX-MLIR "
+        "executables and libraries can be found.")
+
+VERBOSE = os.environ.get('VERBOSE', False)
+ONNX_MLIR = os.path.join(os.environ['ONNX_MLIR_HOME'], "bin/onnx-mlir")
+
+# Include runtime directory in python paths, so pyruntime can be imported.
+RUNTIME_DIR = os.path.join(os.environ['ONNX_MLIR_HOME'], "lib")
+sys.path.append(RUNTIME_DIR)
+
+try:
+    from pyruntime import ExecutionSession
+except ImportError:
+    raise ImportError(
+        "Looks like you did not build the pyruntime target, build it by running `make pyruntime`."
+    )
+
+
+def execute_commands(cmds):
+    if (VERBOSE):
+        print(" ".join(cmds))
+    subprocess.run(cmds, stdout=subprocess.PIPE, check=True)
+
+
+def extend_model_output(model, intermediate_outputs):
+    # onnx-mlir doesn't care about manually specified output types & shapes.
+    DUMMY_TENSOR_TYPE = onnx.TensorProto.FLOAT
+
+    while (len(model.graph.output)):
+        model.graph.output.pop()
+
+    for output_name in intermediate_outputs:
+        output_value_info = onnx.helper.make_tensor_value_info(
+            output_name, DUMMY_TENSOR_TYPE, None)
+        model.graph.output.extend([output_value_info])
+    return model
+
+
+def main(model_path):
+    model = onnx.load(model_path)
+    intermediate_outputs = sum(
+        [list(node.output) for node in model.graph.node], [])
+    intermediate_outputs = list(OrderedDict.fromkeys(intermediate_outputs))
+    model = extend_model_output(model, intermediate_outputs)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        print("Temporary directory has been created at {}".format(temp_dir))
+
+        # Save modified model & invoke onnx-mlir to compile it.
+        temp_model_path = os.path.join(temp_dir, "model.onnx")
+        onnx.save(model, temp_model_path)
+        execute_commands([ONNX_MLIR, temp_model_path])
+
+        # Use the generated shared library to create an execution session.
+        temp_shared_lib_path = os.path.join(temp_dir, "model.so")
+        sess = ExecutionSession(temp_shared_lib_path,
+                                "_dyn_entry_point_main_graph")
+
+        # Generate random data as input.
+        inputs = []
+        input_names = []
+        initializers = list(map(lambda x: x.name, model.graph.initializer))
+        np.random.seed(42)
+        for input_proto in model.graph.input:
+            if input_proto.name not in initializers:
+                input_names.append(input_proto.name)
+                shape_proto = input_proto.type.tensor_type.shape
+                explicit_shape = []
+                for dim in shape_proto.dim:
+                    assert dim.dim_value, "Can only debug models with inputs that have explicit shapes."
+                    explicit_shape.append(dim.dim_value)
+                inputs.append(
+                    np.random.uniform(-1.0, 1.0, explicit_shape).astype(np.float32))
+
+        # Run the compiled inference function on the randomly generated data.
+        outs = sess.run(inputs)
+
+        # Run the model with reference backend and get results.
+        ref_session = prepare(temp_model_path)
+        output_names = list(map(lambda x: x.name, model.graph.output))
+        input_feed = dict(zip(input_names, inputs))
+        ref_outs = ref_session.run(output_names, input_feed)
+
+        # For each intermediate output tensor, compare results.
+        for i, name in enumerate(intermediate_outputs):
+            print("Verifying value of {}".format(name))
+            np.testing.assert_array_almost_equal(ref_outs[i], outs[i], decimal=5)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('model_path', type=str, help="Path to the model to debug.")
+    args = parser.parse_args()
+    main(**vars(args))
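
The prerequisites section of the new documentation notes that the reference backend can be swapped out by replacing the code that imports onnxruntime. Since `utils/debug.py` only constructs the reference session via `prepare(model_path)` and calls `run(output_names, input_feed)` on it, one possible shape for such a substitution is a thin adapter. This is a sketch only: `some_backend` is a placeholder for whichever ONNX-compliant backend is chosen (assumed to follow the usual `prepare(model)` / `rep.run(inputs)` backend convention), and `ReferenceSession` is a name introduced here, not something defined in this patch.

```python
import onnx

# Placeholder: substitute the import of the ONNX-compliant backend of your choice.
# import some_backend


class ReferenceSession:
    """Adapter exposing the onnxruntime-style interface that debug.py expects:
    construction from a model path plus run(output_names, input_feed)."""

    def __init__(self, model_path):
        # Backends following the onnx.backend convention are prepared from a ModelProto.
        self._rep = some_backend.prepare(onnx.load(model_path))

    def run(self, output_names, input_feed):
        # Such backends take positional inputs and return outputs in
        # graph-output order, which is the order debug.py iterates over,
        # so output_names is not needed here.
        return list(self._rep.run(list(input_feed.values())))


# In utils/debug.py, replace
#     import onnxruntime
#     prepare = onnxruntime.InferenceSession
# with
#     prepare = ReferenceSession
```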