fix code format (#348)
This commit is contained in:
parent
e1f6ae1336
commit
cc39a92802
|
@ -6,10 +6,10 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <regex>
|
|
||||||
#include <tuple>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
#include "mlir/Analysis/Verifier.h"
|
#include "mlir/Analysis/Verifier.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
@ -56,38 +56,37 @@ std::string legalize_name(std::string name) {
|
||||||
|
|
||||||
struct OnnxOnnfSymbolMapping {
|
struct OnnxOnnfSymbolMapping {
|
||||||
/*!
|
/*!
|
||||||
* Get MLIR tensor by onnx tensor name.
|
* Get MLIR tensor by onnx tensor name.
|
||||||
* @param name onnx tensor name.
|
* @param name onnx tensor name.
|
||||||
* @return onnf tensor corresponding to `name`.
|
* @return onnf tensor corresponding to `name`.
|
||||||
*/
|
*/
|
||||||
mlir::Value* GetTensorByOnnxName(std::string name) {
|
mlir::Value* GetTensorByOnnxName(std::string name) {
|
||||||
return onnx_name2onnf_tensor.at(legalize_name(name));
|
return onnx_name2onnf_tensor.at(legalize_name(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Add a new mapping from onnx tensor name to MLIR symbol.
|
* Add a new mapping from onnx tensor name to MLIR symbol.
|
||||||
* @param name onnx tensor name.
|
* @param name onnx tensor name.
|
||||||
* @param tensor MLIR Value* pointer.
|
* @param tensor MLIR Value* pointer.
|
||||||
*/
|
*/
|
||||||
void AddMapping(std::string name, mlir::Value* tensor) {
|
void AddMapping(std::string name, mlir::Value* tensor) {
|
||||||
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool ContainKey(std::string name) {
|
bool ContainKey(std::string name) {
|
||||||
return onnx_name2onnf_tensor.count(name) != 0;
|
return onnx_name2onnf_tensor.count(name) != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*!
|
/*!
|
||||||
* mapping from onnx tensor names to MLIR tensor.
|
* mapping from onnx tensor names to MLIR tensor.
|
||||||
*/
|
*/
|
||||||
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
|
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SGIRGenImpl {
|
class SGIRGenImpl {
|
||||||
public :
|
public:
|
||||||
SGIRGenImpl(mlir::MLIRContext &context)
|
SGIRGenImpl(mlir::MLIRContext& context)
|
||||||
: context_(context), builder_(&context) {
|
: context_(context), builder_(&context) {
|
||||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
}
|
}
|
||||||
|
@ -97,16 +96,14 @@ public :
|
||||||
return module_;
|
return module_;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mlir::MLIRContext &context_;
|
mlir::MLIRContext& context_;
|
||||||
mlir::ModuleOp module_;
|
mlir::ModuleOp module_;
|
||||||
mlir::OpBuilder builder_;
|
mlir::OpBuilder builder_;
|
||||||
// mapping between string name and symbol
|
// mapping between string name and symbol
|
||||||
OnnxOnnfSymbolMapping sgir_symbols_;
|
OnnxOnnfSymbolMapping sgir_symbols_;
|
||||||
|
|
||||||
mlir::Location UnknownLoc() {
|
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
||||||
return mlir::UnknownLoc::get(&context_);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
|
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
|
||||||
return builder_.getF32Type();
|
return builder_.getF32Type();
|
||||||
|
@ -121,26 +118,28 @@ private:
|
||||||
int dim_numeric_size = shape_proto.dim()[i].dim_value();
|
int dim_numeric_size = shape_proto.dim()[i].dim_value();
|
||||||
if (dim_numeric_size > 0) {
|
if (dim_numeric_size > 0) {
|
||||||
dims.push_back(dim_numeric_size);
|
dims.push_back(dim_numeric_size);
|
||||||
}else { // If dim_value < 0, then dim is parametric.
|
} else { // If dim_value < 0, then dim is parametric.
|
||||||
//TODO Verify the unknown dim size in MLIR
|
// TODO Verify the unknown dim size in MLIR
|
||||||
dims.push_back(-1);
|
dims.push_back(-1);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//TODO How to represent variable length
|
// TODO How to represent variable length
|
||||||
dims.push_back(-1);
|
dims.push_back(-1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) {
|
if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) {
|
||||||
mlir::Type elementType = TypeConvert(input.type().tensor_type().elem_type());
|
mlir::Type elementType =
|
||||||
|
TypeConvert(input.type().tensor_type().elem_type());
|
||||||
llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
|
llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
|
||||||
auto dataType = builder_.getTensorType(llvmdimsAR, elementType);
|
auto dataType = builder_.getTensorType(llvmdimsAR, elementType);
|
||||||
mlir::OperationState result(UnknownLoc(), "sgir.input "+input_tensor_legalized_name);
|
mlir::OperationState result(
|
||||||
|
UnknownLoc(), "sgir.input " + input_tensor_legalized_name);
|
||||||
result.addTypes(dataType);
|
result.addTypes(dataType);
|
||||||
auto op = builder_.createOperation(result);
|
auto op = builder_.createOperation(result);
|
||||||
auto value = op->getResult(0);
|
auto value = op->getResult(0);
|
||||||
sgir_symbols_.AddMapping(input_tensor_legalized_name, value);
|
sgir_symbols_.AddMapping(input_tensor_legalized_name, value);
|
||||||
} else {
|
} else {
|
||||||
//TODO Should not happen
|
// TODO Should not happen
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,73 +150,71 @@ private:
|
||||||
inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mlir::OperationState result(UnknownLoc(), "SGIR."+node.op_type());
|
mlir::OperationState result(UnknownLoc(), "SGIR." + node.op_type());
|
||||||
for (auto item : node.output()) {
|
for (auto item : node.output()) {
|
||||||
result.addTypes(builder_.getTensorType(builder_.getF32Type()));
|
result.addTypes(builder_.getTensorType(builder_.getF32Type()));
|
||||||
}
|
}
|
||||||
result.addOperands(inputs);
|
result.addOperands(inputs);
|
||||||
auto op = builder_.createOperation(result);
|
auto op = builder_.createOperation(result);
|
||||||
for (int i=0 ; i< node.output().size(); i++) {
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
auto r = builder_.createOperation(result)->getResult(i);
|
auto r = builder_.createOperation(result)->getResult(i);
|
||||||
sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO more info from node: attributes
|
// TODO more info from node: attributes
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportOutputTensor(onnx::ValueInfoProto& output) {
|
void ImportOutputTensor(onnx::ValueInfoProto& output) {
|
||||||
if(sgir_symbols_.ContainKey(legalize_name(output.name()))) {
|
if (sgir_symbols_.ContainKey(legalize_name(output.name()))) {
|
||||||
mlir::OperationState result(UnknownLoc(), "sgir.output "+output.name());
|
mlir::OperationState result(UnknownLoc(), "sgir.output " + output.name());
|
||||||
result.addTypes(builder_.getTensorType(builder_.getF32Type()));
|
result.addTypes(builder_.getTensorType(builder_.getF32Type()));
|
||||||
result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name()));
|
result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name()));
|
||||||
builder_.createOperation(result);
|
builder_.createOperation(result);
|
||||||
} else {
|
} else {
|
||||||
//TODO: Why not in the symbol table? something is wrong
|
// TODO: Why not in the symbol table? something is wrong
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void ImportGraph(onnx::GraphProto graph) {
|
void ImportGraph(onnx::GraphProto graph) {
|
||||||
//create a function for the graph
|
// create a function for the graph
|
||||||
//TODO:
|
// TODO:
|
||||||
// * get name and type for the function.
|
// * get name and type for the function.
|
||||||
// * maintain a list of the defined graph
|
// * maintain a list of the defined graph
|
||||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||||
auto func_type = builder_.getFunctionType(arg_types, ret_types);
|
auto func_type = builder_.getFunctionType(arg_types, ret_types);
|
||||||
auto llvmfunction = mlir::FuncOp::create(UnknownLoc(),
|
auto llvmfunction = mlir::FuncOp::create(
|
||||||
graph.name(), func_type, /* attrs = */ {});
|
UnknownLoc(), graph.name(), func_type, /* attrs = */ {});
|
||||||
auto &entryBlock = *llvmfunction.addEntryBlock();
|
auto& entryBlock = *llvmfunction.addEntryBlock();
|
||||||
builder_.setInsertionPointToStart(&entryBlock);
|
builder_.setInsertionPointToStart(&entryBlock);
|
||||||
module_.push_back(llvmfunction);
|
module_.push_back(llvmfunction);
|
||||||
|
|
||||||
//TODO: import the initializer
|
// TODO: import the initializer
|
||||||
//
|
//
|
||||||
|
|
||||||
//import the input tensors
|
// import the input tensors
|
||||||
for (auto input : graph.input()) {
|
for (auto input : graph.input()) {
|
||||||
ImportInputTensor(input);
|
ImportInputTensor(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
//import nodes in the graph
|
// import nodes in the graph
|
||||||
auto node = graph.node();
|
auto node = graph.node();
|
||||||
for (auto item: node) {
|
for (auto item : node) {
|
||||||
ImportNode(item);
|
ImportNode(item);
|
||||||
}
|
}
|
||||||
|
|
||||||
//import the output tensors
|
// import the output tensors
|
||||||
for (auto output : graph.output()) {
|
for (auto output : graph.output()) {
|
||||||
ImportOutputTensor(output);
|
ImportOutputTensor(output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} ; //SGIRGenImpl class
|
}; // SGIRGenImpl class
|
||||||
|
|
||||||
} //namespace
|
} // namespace
|
||||||
} //namespace onnf
|
} // namespace dlc
|
||||||
|
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Generate SGIR with MLIR for a onnx model
|
* Generate SGIR with MLIR for a onnx model
|
||||||
* @param model onnx model.
|
* @param model onnx model.
|
||||||
|
@ -232,5 +229,4 @@ mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) {
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
|
|
||||||
} //namespace onnf
|
} // namespace onnf
|
||||||
|
|
||||||
|
|
|
@ -21,15 +21,14 @@
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class MLIRContext;
|
class MLIRContext;
|
||||||
class OwningModuleRef;
|
class OwningModuleRef;
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
/*!
|
/*!
|
||||||
* Import an ONNX Model into SGIR
|
* Import an ONNX Model into SGIR
|
||||||
* @param model onnx model.
|
* @param model onnx model.
|
||||||
* @return MLIR::module generated for the ONNX model
|
* @return MLIR::module generated for the ONNX model
|
||||||
*/
|
*/
|
||||||
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
|
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
|
||||||
|
|
||||||
} //namespace onnf
|
|
||||||
|
|
||||||
|
} // namespace onnf
|
||||||
|
|
Loading…
Reference in New Issue