From dc36fd416be993aba6c10f0a7c53e9bea818aeb2 Mon Sep 17 00:00:00 2001
From: GHEORGHE-TEOD BERCEA
Date: Fri, 15 Nov 2019 13:10:41 -0500
Subject: [PATCH] [MLIR] Fix shape inference and enable ONNF to read in
 MLIR files. (#367)

* Fix inference. Enable ONNF to read in MLIR files.

* Fix input of onnx or mlir models.

* Address comments.
---
 src/compiler/dialect/onnx/onnx_ops.cpp     | 21 +++++++--
 src/compiler/pass/shape_inference_pass.cpp |  7 ++-
 src/main.cpp                               | 51 +++++++++++++++++++++-
 3 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index bca52f1..6d79313 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -51,9 +51,14 @@ void ONNXAddOp::inferShapes() {
 // MatMul
 
 void ONNXMatMulOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
@@ -67,9 +72,14 @@ void ONNXMatMulOp::inferShapes() {
 // Gemm
 
 void ONNXGemmOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
@@ -77,9 +87,14 @@ void ONNXGemmOp::inferShapes() {
 // FullGemm
 
 void ONNXFullGemmOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp
index cf556df..78311bf 100644
--- a/src/compiler/pass/shape_inference_pass.cpp
+++ b/src/compiler/pass/shape_inference_pass.cpp
@@ -13,8 +13,8 @@
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Pass/Pass.h"
 
-#include "src/compiler/dialect/onnx/onnx_ops.hpp"
 #include "shape_inference_interface.hpp"
+#include "src/compiler/dialect/onnx/onnx_ops.hpp"
 
 #include "passes.hpp"
 
@@ -82,7 +82,10 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     // All operations which do not return a ranked tensor type have dynamic
    // shaped outputs. All those operation need to implement the inferShape()
     // method.
-    if (op->getName().getStringRef() != "onnx.add")
+    if (op->getName().getStringRef() != "onnx.add" &&
+        op->getName().getStringRef() != "onnx.matmul" &&
+        op->getName().getStringRef() != "onnx.gemm" &&
+        op->getName().getStringRef() != "onnx.full_gemm")
       return false;
     return llvm::any_of(op->getResultTypes(),
         [](Type result_type) { return !result_type.isa<RankedTensorType>(); });
diff --git a/src/main.cpp b/src/main.cpp
index 0a8beb8..e0e7045 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -28,6 +28,10 @@
 #include
 
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+
 #include "src/builder/frontend_dialect_transformer.hpp"
 #include "src/compiler/dialect/krnl/krnl_ops.hpp"
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
@@ -47,6 +51,28 @@
 using namespace std;
 using namespace onnf;
 
+void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
+    mlir::OwningModuleRef& module) {
+  // Handle '.mlir' input to the DLC compiler.
+  // The mlir format indicates that one or more of the supported
+  // representations are used in the file.
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (std::error_code EC = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+    return;
+  }
+
+  // Parse the input mlir.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+  module = mlir::parseSourceFile(sourceMgr, &context);
+  if (!module) {
+    llvm::errs() << "Error can't load file " << inputFilename << "\n";
+    return;
+  }
+}
+
 int main(int ac, char* av[]) {
   namespace po = boost::program_options;
 
@@ -57,8 +83,16 @@ int main(int ac, char* av[]) {
       "onnx model file");
   // clang-format on
 
+  // Handle command line arguments with option names and positional
+  // command line arguments.
+  po::positional_options_description p;
+  p.add("onnx-model", -1);
   po::variables_map vm;
-  po::store(po::parse_command_line(ac, av, desc), vm);
+  po::store(
+      po::command_line_parser(ac, av).options(desc).positional(p).run(), vm);
+
+  // TODO: allow multiple input files
+  assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
 
   if (vm.count("help")) {
     cout << desc << endl;
@@ -71,8 +105,21 @@ int main(int ac, char* av[]) {
   mlir::MLIRContext context;
   mlir::OwningModuleRef module;
 
+  // Decide if the input file is an ONNX model or a model specified
+  // in MLIR. The extension of the file is the decider.
   string model_filename = vm["onnx-model"].as<string>();
-  ImportFrontendModelFile(model_filename, context, module);
+  string extension =
+      model_filename.substr(model_filename.find_last_of(".") + 1);
+  bool onnx_model_provided = (extension == "onnx");
+  bool mlir_model_provided = (extension == "mlir");
+
+  if (onnx_model_provided) {
+    ImportFrontendModelFile(model_filename, context, module);
+  } else if (mlir_model_provided) {
+    LoadMLIR(model_filename, context, module);
+  } else {
+    assert(false && "No ONNX or MLIR models provided!");
+  }
 
   mlir::PassManager pm(&context);
   pm.addPass(mlir::createShapeInferencePass());
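
A note on the shape-inference fix above: the important change is replacing
"SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);" with a default-constructed
vector plus two emplace_back calls. The old form invoked the size constructor,
producing a vector of shape[0] zero-valued elements rather than a one-element
vector holding shape[0]. The following standalone sketch is illustrative only;
it substitutes std::vector for llvm::SmallVector, whose size constructor has
the same semantics:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  int64_t rows = 5, cols = 7;

  // Old code: SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
  // The size constructor creates `rows` zero-initialized elements,
  // not a one-element vector containing the value `rows`.
  std::vector<int64_t> buggy(rows);
  buggy.emplace_back(cols);
  std::cout << "buggy result rank: " << buggy.size() << "\n"; // prints 6

  // New code: start empty, then emplace both dimensions explicitly,
  // yielding the intended [rows, cols] result shape.
  std::vector<int64_t> fixed;
  fixed.emplace_back(rows);
  fixed.emplace_back(cols);
  std::cout << "fixed result rank: " << fixed.size() << "\n"; // prints 2
  return 0;
}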
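
A note on the new input handling: the extension-based dispatch in main.cpp can
be exercised in isolation. Below is a minimal sketch; the modelKind helper is
hypothetical and only mirrors the patch's find_last_of/substr logic, including
the quirk that a dot-less filename becomes its own "extension" (find_last_of
returns npos, and npos + 1 wraps to 0, so substr returns the whole name):

#include <cassert>
#include <iostream>
#include <string>

// Hypothetical helper mirroring the dispatch added to main():
// returns "onnx" or "mlir" for recognized files, "" otherwise.
std::string modelKind(const std::string& filename) {
  std::string extension =
      filename.substr(filename.find_last_of(".") + 1);
  if (extension == "onnx" || extension == "mlir")
    return extension;
  return "";
}

int main() {
  assert(modelKind("model.onnx") == "onnx"); // ImportFrontendModelFile path
  assert(modelKind("model.mlir") == "mlir"); // LoadMLIR path
  assert(modelKind("model.txt").empty());    // would trip assert(false) in main
  assert(modelKind("model").empty());        // no dot: whole name as extension
  std::cout << "extension dispatch OK\n";
  return 0;
}

Dispatching on the extension keeps the driver simple; sniffing file contents
would also work, but the patch explicitly makes the extension the decider.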