fix code format (#348)

This commit is contained in:
Tian Jin 2019-10-07 19:47:46 -04:00 committed by Doru Bercea
parent e1f6ae1336
commit cc39a92802
2 changed files with 59 additions and 64 deletions

View File

@ -6,10 +6,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <regex>
#include <tuple>
#include <numeric> #include <numeric>
#include <regex>
#include <string> #include <string>
#include <tuple>
#include "mlir/Analysis/Verifier.h" #include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
@ -73,7 +73,6 @@ struct OnnxOnnfSymbolMapping {
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
} }
bool ContainKey(std::string name) { bool ContainKey(std::string name) {
return onnx_name2onnf_tensor.count(name) != 0; return onnx_name2onnf_tensor.count(name) != 0;
} }
@ -86,8 +85,8 @@ struct OnnxOnnfSymbolMapping {
}; };
class SGIRGenImpl { class SGIRGenImpl {
public : public:
SGIRGenImpl(mlir::MLIRContext &context) SGIRGenImpl(mlir::MLIRContext& context)
: context_(context), builder_(&context) { : context_(context), builder_(&context) {
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
} }
@ -97,16 +96,14 @@ public :
return module_; return module_;
} }
private: private:
mlir::MLIRContext &context_; mlir::MLIRContext& context_;
mlir::ModuleOp module_; mlir::ModuleOp module_;
mlir::OpBuilder builder_; mlir::OpBuilder builder_;
// mapping between string name and symbol // mapping between string name and symbol
OnnxOnnfSymbolMapping sgir_symbols_; OnnxOnnfSymbolMapping sgir_symbols_;
mlir::Location UnknownLoc() { mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
return mlir::UnknownLoc::get(&context_);
}
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) { mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
return builder_.getF32Type(); return builder_.getF32Type();
@ -121,26 +118,28 @@ private:
int dim_numeric_size = shape_proto.dim()[i].dim_value(); int dim_numeric_size = shape_proto.dim()[i].dim_value();
if (dim_numeric_size > 0) { if (dim_numeric_size > 0) {
dims.push_back(dim_numeric_size); dims.push_back(dim_numeric_size);
}else { // If dim_value < 0, then dim is parametric. } else { // If dim_value < 0, then dim is parametric.
//TODO Verify the unknown dim size in MLIR // TODO Verify the unknown dim size in MLIR
dims.push_back(-1); dims.push_back(-1);
} }
} else { } else {
//TODO How to represent variable length // TODO How to represent variable length
dims.push_back(-1); dims.push_back(-1);
} }
} }
if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) { if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) {
mlir::Type elementType = TypeConvert(input.type().tensor_type().elem_type()); mlir::Type elementType =
TypeConvert(input.type().tensor_type().elem_type());
llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size()); llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
auto dataType = builder_.getTensorType(llvmdimsAR, elementType); auto dataType = builder_.getTensorType(llvmdimsAR, elementType);
mlir::OperationState result(UnknownLoc(), "sgir.input "+input_tensor_legalized_name); mlir::OperationState result(
UnknownLoc(), "sgir.input " + input_tensor_legalized_name);
result.addTypes(dataType); result.addTypes(dataType);
auto op = builder_.createOperation(result); auto op = builder_.createOperation(result);
auto value = op->getResult(0); auto value = op->getResult(0);
sgir_symbols_.AddMapping(input_tensor_legalized_name, value); sgir_symbols_.AddMapping(input_tensor_legalized_name, value);
} else { } else {
//TODO Should not happen // TODO Should not happen
} }
} }
@ -151,73 +150,71 @@ private:
inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item)); inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item));
} }
} }
mlir::OperationState result(UnknownLoc(), "SGIR."+node.op_type()); mlir::OperationState result(UnknownLoc(), "SGIR." + node.op_type());
for (auto item : node.output()) { for (auto item : node.output()) {
result.addTypes(builder_.getTensorType(builder_.getF32Type())); result.addTypes(builder_.getTensorType(builder_.getF32Type()));
} }
result.addOperands(inputs); result.addOperands(inputs);
auto op = builder_.createOperation(result); auto op = builder_.createOperation(result);
for (int i=0 ; i< node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {
auto r = builder_.createOperation(result)->getResult(i); auto r = builder_.createOperation(result)->getResult(i);
sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r); sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r);
} }
//TODO more info from node: attributes // TODO more info from node: attributes
} }
void ImportOutputTensor(onnx::ValueInfoProto& output) { void ImportOutputTensor(onnx::ValueInfoProto& output) {
if(sgir_symbols_.ContainKey(legalize_name(output.name()))) { if (sgir_symbols_.ContainKey(legalize_name(output.name()))) {
mlir::OperationState result(UnknownLoc(), "sgir.output "+output.name()); mlir::OperationState result(UnknownLoc(), "sgir.output " + output.name());
result.addTypes(builder_.getTensorType(builder_.getF32Type())); result.addTypes(builder_.getTensorType(builder_.getF32Type()));
result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name())); result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name()));
builder_.createOperation(result); builder_.createOperation(result);
} else { } else {
//TODO: Why not in the symbol table? something is wrong // TODO: Why not in the symbol table? something is wrong
} }
} }
void ImportGraph(onnx::GraphProto graph) { void ImportGraph(onnx::GraphProto graph) {
//create a function for the graph // create a function for the graph
//TODO: // TODO:
// * get name and type for the function. // * get name and type for the function.
// * maintain a list of the defined graph // * maintain a list of the defined graph
llvm::SmallVector<mlir::Type, 4> ret_types; llvm::SmallVector<mlir::Type, 4> ret_types;
llvm::SmallVector<mlir::Type, 4> arg_types; llvm::SmallVector<mlir::Type, 4> arg_types;
auto func_type = builder_.getFunctionType(arg_types, ret_types); auto func_type = builder_.getFunctionType(arg_types, ret_types);
auto llvmfunction = mlir::FuncOp::create(UnknownLoc(), auto llvmfunction = mlir::FuncOp::create(
graph.name(), func_type, /* attrs = */ {}); UnknownLoc(), graph.name(), func_type, /* attrs = */ {});
auto &entryBlock = *llvmfunction.addEntryBlock(); auto& entryBlock = *llvmfunction.addEntryBlock();
builder_.setInsertionPointToStart(&entryBlock); builder_.setInsertionPointToStart(&entryBlock);
module_.push_back(llvmfunction); module_.push_back(llvmfunction);
//TODO: import the initializer // TODO: import the initializer
// //
//import the input tensors // import the input tensors
for (auto input : graph.input()) { for (auto input : graph.input()) {
ImportInputTensor(input); ImportInputTensor(input);
} }
//import nodes in the graph // import nodes in the graph
auto node = graph.node(); auto node = graph.node();
for (auto item: node) { for (auto item : node) {
ImportNode(item); ImportNode(item);
} }
//import the output tensors // import the output tensors
for (auto output : graph.output()) { for (auto output : graph.output()) {
ImportOutputTensor(output); ImportOutputTensor(output);
} }
} }
} ; //SGIRGenImpl class }; // SGIRGenImpl class
} //namespace } // namespace
} //namespace onnf } // namespace dlc
namespace onnf { namespace onnf {
/*! /*!
* Generate SGIR with MLIR for a onnx model * Generate SGIR with MLIR for a onnx model
* @param model onnx model. * @param model onnx model.
@ -232,5 +229,4 @@ mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) {
return module; return module;
} }
} //namespace onnf } // namespace onnf

View File

@ -24,12 +24,11 @@ class OwningModuleRef;
} // namespace mlir } // namespace mlir
namespace onnf { namespace onnf {
/*! /*!
* Import an ONNX Model into SGIR * Import an ONNX Model into SGIR
* @param model onnx model. * @param model onnx model.
* @return MLIR::module generated for the ONNX model * @return MLIR::module generated for the ONNX model
*/ */
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model); mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
} //namespace onnf
} // namespace onnf