[MLIR] Fix shape inference and enable ONNF to read in MLIR files. (#367)
* Fix shape inference. * Enable ONNF to read in MLIR files. * Accept either ONNX or MLIR model files as input. * Address review comments.
This commit is contained in:
parent
0d644fab92
commit
dc36fd416b
|
@ -51,9 +51,14 @@ void ONNXAddOp::inferShapes() {
|
|||
// MatMul
|
||||
|
||||
void ONNXMatMulOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand(0)->getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1)->getType().isa<RankedTensorType>())
|
||||
return;
|
||||
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||
SmallVector<int64_t, 2> dims;
|
||||
dims.emplace_back(lhsTy.getShape()[0]);
|
||||
dims.emplace_back(rhsTy.getShape()[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
@ -67,9 +72,14 @@ void ONNXMatMulOp::inferShapes() {
|
|||
// Gemm
|
||||
|
||||
void ONNXGemmOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand(0)->getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1)->getType().isa<RankedTensorType>())
|
||||
return;
|
||||
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||
SmallVector<int64_t, 2> dims;
|
||||
dims.emplace_back(lhsTy.getShape()[0]);
|
||||
dims.emplace_back(rhsTy.getShape()[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
@ -77,9 +87,14 @@ void ONNXGemmOp::inferShapes() {
|
|||
// FullGemm
|
||||
|
||||
void ONNXFullGemmOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand(0)->getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1)->getType().isa<RankedTensorType>())
|
||||
return;
|
||||
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||
SmallVector<int64_t, 2> dims;
|
||||
dims.emplace_back(lhsTy.getShape()[0]);
|
||||
dims.emplace_back(rhsTy.getShape()[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
#include "shape_inference_interface.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
|
||||
#include "passes.hpp"
|
||||
|
||||
|
@ -82,7 +82,10 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
|||
// All operations which do not return a ranked tensor type have dynamic
|
||||
// shaped outputs. All those operation need to implement the inferShape()
|
||||
// method.
|
||||
if (op->getName().getStringRef() != "onnx.add")
|
||||
if (op->getName().getStringRef() != "onnx.add" &&
|
||||
op->getName().getStringRef() != "onnx.matmul" &&
|
||||
op->getName().getStringRef() != "onnx.gemm" &&
|
||||
op->getName().getStringRef() != "onnx.full_gemm")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(),
|
||||
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
||||
|
|
49
src/main.cpp
49
src/main.cpp
|
@ -28,6 +28,10 @@
|
|||
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "llvm/Support/FileUtilities.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
#include "src/builder/frontend_dialect_transformer.hpp"
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
|
@ -47,6 +51,28 @@
|
|||
using namespace std;
|
||||
using namespace onnf;
|
||||
|
||||
void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
|
||||
mlir::OwningModuleRef& module) {
|
||||
// Handle '.mlir' input to the DLC compiler.
|
||||
// The mlir format indicates that one or more of the supported
|
||||
// representations are used in the file.
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||
if (std::error_code EC = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse the input mlir.
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
if (!module) {
|
||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int ac, char* av[]) {
|
||||
namespace po = boost::program_options;
|
||||
|
||||
|
@ -57,8 +83,16 @@ int main(int ac, char* av[]) {
|
|||
"onnx model file");
|
||||
// clang-format on
|
||||
|
||||
// Handle command line argument with option names and positional
|
||||
// command line arguments.
|
||||
po::positional_options_description p;
|
||||
p.add("onnx-model", -1);
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(ac, av, desc), vm);
|
||||
po::store(
|
||||
po::command_line_parser(ac, av).options(desc).positional(p).run(), vm);
|
||||
|
||||
// TODO: allow multiple input files
|
||||
assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
|
||||
|
||||
if (vm.count("help")) {
|
||||
cout << desc << endl;
|
||||
|
@ -71,8 +105,21 @@ int main(int ac, char* av[]) {
|
|||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
|
||||
// Decide if the input file is an ONNX model or a model specified
|
||||
// in MLIR. The extension of the file is the decider.
|
||||
string model_filename = vm["onnx-model"].as<string>();
|
||||
string extension =
|
||||
model_filename.substr(model_filename.find_last_of(".") + 1);
|
||||
bool onnx_model_provided = (extension == "onnx");
|
||||
bool mlir_model_provided = (extension == "mlir");
|
||||
|
||||
if (onnx_model_provided) {
|
||||
ImportFrontendModelFile(model_filename, context, module);
|
||||
} else if (mlir_model_provided) {
|
||||
LoadMLIR(model_filename, context, module);
|
||||
} else {
|
||||
assert(false && "No ONNX or MLIR models provided!");
|
||||
}
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
|
|
Loading…
Reference in New Issue