[MLIR] Fix shape inference and enable ONNF to read in MLIR files. (#367)

* Fix inference. Enable ONNF to read in MLIR files.

* Fix handling of ONNX and MLIR model inputs.

* Address comments.
GHEORGHE-TEOD BERCEA 2019-11-15 13:10:41 -05:00 committed by Tian Jin
parent 0d644fab92
commit dc36fd416b
3 changed files with 72 additions and 7 deletions


@@ -51,9 +51,14 @@ void ONNXAddOp::inferShapes() {
 // MatMul
 void ONNXMatMulOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
@@ -67,9 +72,14 @@ void ONNXMatMulOp::inferShapes() {
 // Gemm
 void ONNXGemmOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
@@ -77,9 +87,14 @@ void ONNXGemmOp::inferShapes() {
 // FullGemm
 void ONNXFullGemmOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
   auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }


@@ -13,8 +13,8 @@
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Pass/Pass.h"
-#include "src/compiler/dialect/onnx/onnx_ops.hpp"
 #include "shape_inference_interface.hpp"
+#include "src/compiler/dialect/onnx/onnx_ops.hpp"
 #include "passes.hpp"
@@ -82,7 +82,10 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     // All operations which do not return a ranked tensor type have dynamic
     // shaped outputs. All those operation need to implement the inferShape()
     // method.
-    if (op->getName().getStringRef() != "onnx.add")
+    if (op->getName().getStringRef() != "onnx.add" &&
+        op->getName().getStringRef() != "onnx.matmul" &&
+        op->getName().getStringRef() != "onnx.gemm" &&
+        op->getName().getStringRef() != "onnx.full_gemm")
       return false;
     return llvm::any_of(op->getResultTypes(),
         [](Type result_type) { return !result_type.isa<RankedTensorType>(); });
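The whitelist above can equally live in a single table, which keeps the condition readable as more ops gain inferShapes(); a standalone sketch under that assumption (the helper name and container choice are mine, not the project's):

#include <algorithm>
#include <array>
#include <string_view>

// Illustrative only: the same four op names the pass checks, expressed as a
// table lookup instead of chained != comparisons.
bool hasShapeInference(std::string_view opName) {
  static constexpr std::array<std::string_view, 4> kOps = {
      "onnx.add", "onnx.matmul", "onnx.gemm", "onnx.full_gemm"};
  return std::find(kOps.begin(), kOps.end(), opName) != kOps.end();
}

int main() {
  return (hasShapeInference("onnx.matmul") && !hasShapeInference("onnx.relu")) ? 0 : 1;
}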


@@ -28,6 +28,10 @@
 #include <boost/program_options.hpp>
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
 #include "src/builder/frontend_dialect_transformer.hpp"
 #include "src/compiler/dialect/krnl/krnl_ops.hpp"
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
@@ -47,6 +51,28 @@
 using namespace std;
 using namespace onnf;
 
+void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
+    mlir::OwningModuleRef& module) {
+  // Handle '.mlir' input to the DLC compiler.
+  // The mlir format indicates that one or more of the supported
+  // representations are used in the file.
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (std::error_code EC = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+    return;
+  }
+
+  // Parse the input mlir.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+  module = mlir::parseSourceFile(sourceMgr, &context);
+  if (!module) {
+    llvm::errs() << "Error can't load file " << inputFilename << "\n";
+    return;
+  }
+}
+
 int main(int ac, char* av[]) {
   namespace po = boost::program_options;
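LoadMLIR reports failures to llvm::errs() and returns with module left null, so the caller has to check the module before running passes over it. Below is a plain standard-library analogue of that load-or-report-and-bail structure; loadTextFile and the file name are illustrative, and the real code uses llvm::MemoryBuffer::getFileOrSTDIN plus mlir::parseSourceFile rather than a raw string:

#include <fstream>
#include <iostream>
#include <optional>
#include <sstream>
#include <string>

// Illustrative analogue: print the error and return an empty optional on
// failure, so the caller's check mirrors the `if (!module)` test above.
std::optional<std::string> loadTextFile(const std::string &inputFilename) {
  std::ifstream file(inputFilename);
  if (!file) {
    std::cerr << "Could not open input file: " << inputFilename << "\n";
    return std::nullopt;
  }
  std::ostringstream contents;
  contents << file.rdbuf();  // slurp the whole file
  return contents.str();
}

int main() {
  auto text = loadTextFile("model.mlir");  // hypothetical input path
  if (!text)
    return 1;  // error was already reported
  std::cout << "loaded " << text->size() << " bytes\n";
  return 0;
}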
@@ -57,8 +83,16 @@ int main(int ac, char* av[]) {
       "onnx model file");
   // clang-format on
 
+  // Handle command line argument with option names and positional
+  // command line arguments.
+  po::positional_options_description p;
+  p.add("onnx-model", -1);
+
   po::variables_map vm;
-  po::store(po::parse_command_line(ac, av, desc), vm);
+  po::store(
+      po::command_line_parser(ac, av).options(desc).positional(p).run(), vm);
+
+  // TODO: allow multiple input files
+  assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
 
   if (vm.count("help")) {
     cout << desc << endl;
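The switch from parse_command_line to command_line_parser(...).positional(p) is what allows the model path to be passed as a bare positional argument rather than via --onnx-model. A self-contained Boost.ProgramOptions sketch of the same wiring, with an illustrative option name:

#include <boost/program_options.hpp>
#include <iostream>
#include <string>

namespace po = boost::program_options;

// Minimal illustration: "input" may be given as `--input file` or as a bare
// positional argument, the same pattern used for onnx-model above.
int main(int argc, char *argv[]) {
  po::options_description desc("options");
  desc.add_options()("input", po::value<std::string>(), "input file");

  po::positional_options_description p;
  p.add("input", -1);  // every positional argument maps to --input

  po::variables_map vm;
  po::store(
      po::command_line_parser(argc, argv).options(desc).positional(p).run(), vm);
  po::notify(vm);

  if (vm.count("input"))
    std::cout << "input file: " << vm["input"].as<std::string>() << "\n";
  return 0;
}

With p.add("input", -1), an invocation like ./a.out model.onnx behaves like ./a.out --input model.onnx; the -1 means an unlimited number of positional arguments map to that option, which is why the TODO/assert pair above caps the count at one for now.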
@@ -71,8 +105,21 @@ int main(int ac, char* av[]) {
   mlir::MLIRContext context;
   mlir::OwningModuleRef module;
 
+  // Decide if the input file is an ONNX model or a model specified
+  // in MLIR. The extension of the file is the decider.
   string model_filename = vm["onnx-model"].as<string>();
+  string extension =
+      model_filename.substr(model_filename.find_last_of(".") + 1);
+  bool onnx_model_provided = (extension == "onnx");
+  bool mlir_model_provided = (extension == "mlir");
+
+  if (onnx_model_provided) {
     ImportFrontendModelFile(model_filename, context, module);
+  } else if (mlir_model_provided) {
+    LoadMLIR(model_filename, context, module);
+  } else {
+    assert(false && "No ONNX or MLIR models provided!");
+  }
 
   mlir::PassManager pm(&context);
   pm.addPass(mlir::createShapeInferencePass());
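Model selection in the driver is driven purely by the file extension; note that find_last_of(".") returns npos for an extension-less path, npos + 1 wraps to 0, and the whole filename is then compared as the "extension", so such inputs fall through to the assert. A self-contained sketch of that dispatch (classifyModelFile and ModelKind are made-up names; the real code calls ImportFrontendModelFile and LoadMLIR directly):

#include <cassert>
#include <iostream>
#include <string>

// Illustrative stand-in for the extension-based decision in main().
enum class ModelKind { Onnx, Mlir, Unknown };

ModelKind classifyModelFile(const std::string &filename) {
  // Same extraction as above: everything after the last '.'; with no '.',
  // find_last_of returns npos and npos + 1 == 0, so the whole name is used.
  std::string extension = filename.substr(filename.find_last_of(".") + 1);
  if (extension == "onnx")
    return ModelKind::Onnx;
  if (extension == "mlir")
    return ModelKind::Mlir;
  return ModelKind::Unknown;
}

int main() {
  assert(classifyModelFile("model.onnx") == ModelKind::Onnx);
  assert(classifyModelFile("add.mlir") == ModelKind::Mlir);
  assert(classifyModelFile("weights") == ModelKind::Unknown);
  std::cout << "extension dispatch behaves as expected\n";
  return 0;
}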