emit SGIR based on onnx model, following the toy/Ch2 example (#345)
* emit SGIR based on onnx model, following the toy/Ch2 example
* fix 1) code style 2) multiple output of a node
* Update sgir.cpp
This commit is contained in:
parent
63427030ca
commit
50ac6e7bce
sgir.cpp

@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
 #include <regex>
+#include <tuple>
+#include <numeric>
 
 #include "mlir/Analysis/Verifier.h"
@@ -18,51 +20,192 @@
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/raw_ostream.h"
 
 #include "sgir.hpp"
 
 using llvm::cast;
 using llvm::dyn_cast;
 using llvm::isa;
 using llvm::ScopedHashTableScope;
 using llvm::SmallVector;
 using llvm::StringRef;
 using llvm::Twine;
 
 namespace onnf {
 namespace {
 
+struct OnnxOnnfSymbolMapping {
+  /*!
+   * Get MLIR tensor by onnx tensor name.
+   * @param name onnx tensor name.
+   * @return onnf tensor corresponding to `name`.
+   */
+  mlir::Value* GetTensorByOnnxName(std::string name) {
+    return onnx_name2onnf_tensor.at(legalize_name(name));
+  }
+
+  /*!
+   * Add a new mapping from onnx tensor name to MLIR symbol.
+   * @param name onnx tensor name.
+   * @param tensor MLIR Value* pointer.
+   */
+  void AddMapping(std::string name, mlir::Value* tensor) {
+    onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
+  }
+
+  bool ContainKey(std::string name) {
+    return onnx_name2onnf_tensor.count(name) != 0;
+  }
+
+ private:
+  /*!
+   * mapping from onnx tensor names to MLIR tensor.
+   */
+  std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
+};
+
 class SGIRGenImpl {
  public:
   SGIRGenImpl(mlir::MLIRContext &context)
-      : context(context), builder(&context) {}
+      : context_(context), builder_(&context) {
+    module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
+  }
 
-  mlir::ModuleOp mlirGen() {
-    theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
-    return theModule;
+  mlir::ModuleOp ImportModel(onnx::ModelProto model) {
+    ImportGraph(model.graph());
+    return module_;
   }
 
  private:
-  mlir::MLIRContext &context;
-  mlir::ModuleOp theModule;
-  mlir::OpBuilder builder;
+  mlir::MLIRContext &context_;
+  mlir::ModuleOp module_;
+  mlir::OpBuilder builder_;
+  // mapping between string name and symbol
+  OnnxOnnfSymbolMapping sgir_symbols_;
 
-};
+  mlir::Location UnknownLoc() {
+    return mlir::UnknownLoc::get(&context_);
+  }
+
+  mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
+    return builder_.getF32Type();
+  }
+
+  void ImportInputTensor(onnx::ValueInfoProto& input) {
+    std::vector<int64_t> dims;
+    auto shape_proto = input.type().tensor_type().shape();
+    auto input_tensor_legalized_name = legalize_name(input.name());
+    for (int i = 0; i < shape_proto.dim_size(); i++) {
+      if (shape_proto.dim()[i].dim_value()) {
+        int dim_numeric_size = shape_proto.dim()[i].dim_value();
+        if (dim_numeric_size > 0) {
+          dims.push_back(dim_numeric_size);
+        } else {  // If dim_value < 0, then dim is parametric.
+          // TODO: Verify the unknown dim size in MLIR
+          dims.push_back(-1);
+        }
+      } else {
+        // TODO: How to represent variable length
+        dims.push_back(-1);
+      }
+    }
+    if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) {
+      mlir::Type elementType = TypeConvert(input.type().tensor_type().elem_type());
+      llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
+      auto dataType = builder_.getTensorType(llvmdimsAR, elementType);
+      mlir::OperationState result(UnknownLoc(), "sgir.input " + input_tensor_legalized_name);
+      result.addTypes(dataType);
+      auto op = builder_.createOperation(result);
+      auto value = op->getResult(0);
+      sgir_symbols_.AddMapping(input_tensor_legalized_name, value);
+    } else {
+      // TODO: Should not happen
+    }
+  }
+
+  void ImportNode(onnx::NodeProto node) {
+    std::vector<mlir::Value*> inputs;
+    for (auto item : node.input()) {
+      if (sgir_symbols_.ContainKey(legalize_name(item))) {
+        inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item));
+      }
+    }
+    mlir::OperationState result(UnknownLoc(), "SGIR." + node.op_type());
+    for (auto item : node.output()) {
+      result.addTypes(builder_.getTensorType(builder_.getF32Type()));
+    }
+    result.addOperands(inputs);
+    auto op = builder_.createOperation(result);
+    // Map every output name of the node to the corresponding result of the
+    // single operation created above (do not create the operation again).
+    for (int i = 0; i < node.output().size(); i++) {
+      auto r = op->getResult(i);
+      sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r);
+    }
+
+    // TODO: more info from node: attributes
+  }
+
+  void ImportOutputTensor(onnx::ValueInfoProto& output) {
+    if (sgir_symbols_.ContainKey(legalize_name(output.name()))) {
+      mlir::OperationState result(UnknownLoc(), "sgir.output " + output.name());
+      result.addTypes(builder_.getTensorType(builder_.getF32Type()));
+      result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name()));
+      builder_.createOperation(result);
+    } else {
+      // TODO: Why not in the symbol table? something is wrong
+    }
+  }
+
+  void ImportGraph(onnx::GraphProto graph) {
+    // create a function for the graph
+    // TODO:
+    //   * get name and type for the function.
+    //   * maintain a list of the defined graph
+    llvm::SmallVector<mlir::Type, 4> ret_types;
+    llvm::SmallVector<mlir::Type, 4> arg_types;
+    auto func_type = builder_.getFunctionType(arg_types, ret_types);
+    auto llvmfunction = mlir::FuncOp::create(UnknownLoc(),
+        graph.name(), func_type, /* attrs = */ {});
+    auto &entryBlock = *llvmfunction.addEntryBlock();
+    builder_.setInsertionPointToStart(&entryBlock);
+    module_.push_back(llvmfunction);
+
+    // TODO: import the initializer
+
+    // import the input tensors
+    for (auto input : graph.input()) {
+      ImportInputTensor(input);
+    }
+
+    // import nodes in the graph
+    auto node = graph.node();
+    for (auto item : node) {
+      ImportNode(item);
+    }
+
+    // import the output tensors
+    for (auto output : graph.output()) {
+      ImportOutputTensor(output);
+    }
+  }
+
+} ; //SGIRGenImpl class
 
 } //namespace
 } //namespace onnf
 
 namespace onnf {
 
-int SGIRTest() {
+/*!
+ * Generate SGIR with MLIR for a onnx model
+ * @param model onnx model.
+ * @return module mlir module generated for the onnx model
+ */
+mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) {
   mlir::MLIRContext context;
-  mlir::OwningModuleRef module = SGIRGenImpl(context).mlirGen();
-  if (!module)
-    return 1;
-  module->dump();
-  return 0;
+  SGIRGenImpl mySGIRGen(context);
+  auto module = mySGIRGen.ImportModel(model);
+  module.dump();
+
+  return module;
 }
 
 } //namespace onnf
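Not part of the commit: the TypeConvert helper added above is still a stub that always returns f32. Below is a sketch of how it could be extended, written as a free function for readability. The name ConvertOnnxType is hypothetical; it assumes the standard protobuf-generated onnx::TensorProto_DataType constants and the MLIR Builder type getters (getF16Type, getF32Type, getF64Type, getIntegerType) available in this era of MLIR.

#include "mlir/IR/Builders.h"
#include "onnx/onnx_pb.h"

// Hypothetical variant of TypeConvert: map a few ONNX element types to MLIR
// types instead of always returning f32; unhandled types still fall back.
static mlir::Type ConvertOnnxType(onnx::TensorProto_DataType intype,
                                  mlir::Builder &builder) {
  switch (intype) {
  case onnx::TensorProto_DataType_FLOAT16: return builder.getF16Type();
  case onnx::TensorProto_DataType_FLOAT:   return builder.getF32Type();
  case onnx::TensorProto_DataType_DOUBLE:  return builder.getF64Type();
  case onnx::TensorProto_DataType_INT32:   return builder.getIntegerType(32);
  case onnx::TensorProto_DataType_INT64:   return builder.getIntegerType(64);
  default:  // element types not handled yet fall back to f32
    return builder.getF32Type();
  }
}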
sgir.hpp

@@ -8,7 +8,15 @@
 #pragma once
 
+#include <fstream>
+#include <functional>
+#include <map>
+#include <memory>
+#include <sstream>
 #include <string>
+#include <vector>
+
+#include "onnx/onnx_pb.h"
 
 namespace mlir {
 class MLIRContext;
@@ -17,10 +25,11 @@ class OwningModuleRef;
 
 namespace onnf {
 /*!
- * Test dummy
- * @return status, 0 for success, otherwise failure
- **/
-int SGIRTest();
+ * Import an ONNX Model into SGIR
+ * @param model onnx model.
+ * @return MLIR::module generated for the ONNX model
+ */
+mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
 
 } //namespace onnf
 
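Not part of the commit: a minimal sketch of how the new SGIRImportModel entry point might be driven from a standalone tool. The "model.onnx" path and the main() wrapper are hypothetical; it assumes the model file is parsed with protobuf's ParseFromIstream and that "mlir/IR/Module.h" supplies the definition of mlir::OwningModuleRef.

#include <fstream>

#include "mlir/IR/Module.h"  // mlir::OwningModuleRef (assumed location in this era)
#include "onnx/onnx_pb.h"

#include "sgir.hpp"

int main() {
  // Hypothetical input path; the commit itself does not add a driver.
  std::ifstream file("model.onnx", std::ios::in | std::ios::binary);
  onnx::ModelProto model;
  if (!file || !model.ParseFromIstream(&file))
    return 1;  // could not read or parse the protobuf

  // Import the ONNX graph into an SGIR module and print it.
  mlir::OwningModuleRef module = onnf::SGIRImportModel(model);
  if (!module)
    return 1;
  module->dump();
  return 0;
}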