[MLIR] Fix shape inference and enable ONNF to read in MLIR files. (#367)
* Fix shape inference. Enable ONNF to read in MLIR files.
* Fix handling of ONNX and MLIR model inputs.
* Address review comments.
This commit is contained in:
parent 0d644fab92 · commit dc36fd416b
@@ -51,9 +51,14 @@ void ONNXAddOp::inferShapes() {
 // MatMul
 
 void ONNXMatMulOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
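The `dims` change above is the substance of the inference fix: `SmallVector<int64_t, 2> dims(lhsTy.getShape()[0])` invokes the sized constructor, producing `lhsTy.getShape()[0]` value-initialized (zero) elements rather than a one-element vector holding that dimension. A minimal standalone sketch of the difference (not part of the patch, assuming stock LLVM SmallVector semantics):

// Sketch only, not from the patch: why the sized constructor was wrong.
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

int main() {
  int64_t dim0 = 3;

  // Old code: sized constructor creates three zero-valued elements.
  llvm::SmallVector<int64_t, 2> wrong(dim0);
  assert(wrong.size() == 3 && wrong[0] == 0);

  // Fixed code: start empty, then push the actual dimension value.
  llvm::SmallVector<int64_t, 2> right;
  right.emplace_back(dim0);
  assert(right.size() == 1 && right[0] == 3);
  return 0;
}

The same one-line fix is applied to Gemm and FullGemm in the two hunks below.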
@@ -67,9 +72,14 @@ void ONNXMatMulOp::inferShapes() {
 // Gemm
 
 void ONNXGemmOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
@@ -77,9 +87,14 @@ void ONNXGemmOp::inferShapes() {
 // FullGemm
 
 void ONNXFullGemmOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
@@ -13,8 +13,8 @@
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Pass/Pass.h"
 
-#include "src/compiler/dialect/onnx/onnx_ops.hpp"
 #include "shape_inference_interface.hpp"
+#include "src/compiler/dialect/onnx/onnx_ops.hpp"
 
 #include "passes.hpp"
 
@@ -82,7 +82,10 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     // All operations which do not return a ranked tensor type have dynamic
     // shaped outputs. All those operation need to implement the inferShape()
     // method.
-    if (op->getName().getStringRef() != "onnx.add")
+    if (op->getName().getStringRef() != "onnx.add" &&
+        op->getName().getStringRef() != "onnx.matmul" &&
+        op->getName().getStringRef() != "onnx.gemm" &&
+        op->getName().getStringRef() != "onnx.full_gemm")
       return false;
     return llvm::any_of(op->getResultTypes(),
         [](Type result_type) { return !result_type.isa<RankedTensorType>(); });
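A side note on the hunk above: every op gaining shape inference adds another chained `getStringRef()` comparison. One possible cleanup, a sketch only and not part of this patch, is `llvm::StringSwitch` from LLVM's ADT, which expresses the same whitelist:

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"

// Sketch: same predicate as the chained comparisons in the hunk above.
static bool hasShapeInference(llvm::StringRef opName) {
  return llvm::StringSwitch<bool>(opName)
      .Case("onnx.add", true)
      .Case("onnx.matmul", true)
      .Case("onnx.gemm", true)
      .Case("onnx.full_gemm", true)
      .Default(false);
}

The check would then read `if (!hasShapeInference(op->getName().getStringRef())) return false;`.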
src/main.cpp | 49
@@ -28,6 +28,10 @@
 
 #include <boost/program_options.hpp>
 
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+
 #include "src/builder/frontend_dialect_transformer.hpp"
 #include "src/compiler/dialect/krnl/krnl_ops.hpp"
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
@@ -47,6 +51,28 @@
 using namespace std;
 using namespace onnf;
 
+void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
+    mlir::OwningModuleRef& module) {
+  // Handle '.mlir' input to the DLC compiler.
+  // The mlir format indicates that one or more of the supported
+  // representations are used in the file.
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (std::error_code EC = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+    return;
+  }
+
+  // Parse the input mlir.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+  module = mlir::parseSourceFile(sourceMgr, &context);
+  if (!module) {
+    llvm::errs() << "Error can't load file " << inputFilename << "\n";
+    return;
+  }
+}
+
 int main(int ac, char* av[]) {
   namespace po = boost::program_options;
 
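Worth noting about `LoadMLIR`: it reports failures to `llvm::errs()` but returns `void`, so a null `module` is the only signal a caller gets. A hypothetical call-site sketch (the function and file names here are illustrative, and it assumes the includes already present in src/main.cpp):

// Hypothetical usage sketch, not from the patch.
void demoLoadMLIR() {
  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  LoadMLIR("model.mlir", context, module);  // "model.mlir" is illustrative
  if (!module)
    return;  // LoadMLIR already printed the error; module stayed null
  module->dump();  // print the parsed IR for inspection
}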
@@ -57,8 +83,16 @@ int main(int ac, char* av[]) {
       "onnx model file");
   // clang-format on
 
+  // Handle command line argument with option names and positional
+  // command line arguments.
+  po::positional_options_description p;
+  p.add("onnx-model", -1);
   po::variables_map vm;
-  po::store(po::parse_command_line(ac, av, desc), vm);
+  po::store(
+      po::command_line_parser(ac, av).options(desc).positional(p).run(), vm);
+
+  // TODO: allow multiple input files
+  assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
 
   if (vm.count("help")) {
     cout << desc << endl;
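The switch from `po::parse_command_line` to `po::command_line_parser(...).positional(p)` is what lets the model file be passed as a bare argument: with `p.add("onnx-model", -1)`, both `--onnx-model model.onnx` and plain `model.onnx` populate `vm["onnx-model"]`. A standalone sketch of that behavior (option names match the patch; everything else is illustrative):

#include <boost/program_options.hpp>
#include <iostream>
#include <string>

// Standalone sketch of the parsing set-up in the hunk above.
int main(int ac, char* av[]) {
  namespace po = boost::program_options;
  po::options_description desc("options");
  desc.add_options()("onnx-model", po::value<std::string>(), "onnx model file");

  // Map bare (positional) arguments onto the --onnx-model option.
  po::positional_options_description p;
  p.add("onnx-model", -1);

  po::variables_map vm;
  po::store(po::command_line_parser(ac, av).options(desc).positional(p).run(), vm);
  po::notify(vm);

  if (vm.count("onnx-model"))
    std::cout << vm["onnx-model"].as<std::string>() << "\n";
  return 0;
}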
@@ -71,8 +105,21 @@ int main(int ac, char* av[]) {
   mlir::MLIRContext context;
   mlir::OwningModuleRef module;
 
+  // Decide if the input file is an ONNX model or a model specified
+  // in MLIR. The extension of the file is the decider.
   string model_filename = vm["onnx-model"].as<string>();
+  string extension =
+      model_filename.substr(model_filename.find_last_of(".") + 1);
+  bool onnx_model_provided = (extension == "onnx");
+  bool mlir_model_provided = (extension == "mlir");
+
+  if (onnx_model_provided) {
     ImportFrontendModelFile(model_filename, context, module);
+  } else if (mlir_model_provided) {
+    LoadMLIR(model_filename, context, module);
+  } else {
+    assert(false && "No ONNX or MLIR models provided!");
+  }
 
   mlir::PassManager pm(&context);
   pm.addPass(mlir::createShapeInferencePass());
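One subtlety in the extension check above: for a filename with no dot, `find_last_of(".")` returns `string::npos`, and since `npos + 1` wraps to 0 in unsigned arithmetic, `substr` returns the whole filename. Such inputs fail both comparisons and land on the assert, which is the intended behavior. A standalone sketch of that edge case (helper name is illustrative):

#include <cassert>
#include <string>

// Sketch of the extension logic above, including the no-dot edge case.
std::string extensionOf(const std::string& filename) {
  // npos + 1 wraps to 0, so a dotless name returns itself unchanged.
  return filename.substr(filename.find_last_of(".") + 1);
}

int main() {
  assert(extensionOf("model.onnx") == "onnx");
  assert(extensionOf("model.mlir") == "mlir");
  assert(extensionOf("model") == "model");  // no dot: whole name comes back
  return 0;
}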