fix code format (#348)

This commit is contained in:
Tian Jin 2019-10-07 19:47:46 -04:00 committed by Doru Bercea
parent e1f6ae1336
commit cc39a92802
2 changed files with 59 additions and 64 deletions

View File

@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
#include <regex>
#include <tuple>
#include <numeric>
#include <regex>
#include <string>
#include <tuple>
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
@ -73,7 +73,6 @@ struct OnnxOnnfSymbolMapping {
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
}
bool ContainKey(std::string name) {
return onnx_name2onnf_tensor.count(name) != 0;
}
@ -86,8 +85,8 @@ struct OnnxOnnfSymbolMapping {
};
class SGIRGenImpl {
public :
SGIRGenImpl(mlir::MLIRContext &context)
public:
SGIRGenImpl(mlir::MLIRContext& context)
: context_(context), builder_(&context) {
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
}
@ -97,16 +96,14 @@ public :
return module_;
}
private:
mlir::MLIRContext &context_;
private:
mlir::MLIRContext& context_;
mlir::ModuleOp module_;
mlir::OpBuilder builder_;
// mapping between string name and symbol
OnnxOnnfSymbolMapping sgir_symbols_;
mlir::Location UnknownLoc() {
return mlir::UnknownLoc::get(&context_);
}
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
return builder_.getF32Type();
@ -121,26 +118,28 @@ private:
int dim_numeric_size = shape_proto.dim()[i].dim_value();
if (dim_numeric_size > 0) {
dims.push_back(dim_numeric_size);
}else { // If dim_value < 0, then dim is parametric.
//TODO Verify the unknown dim size in MLIR
} else { // If dim_value < 0, then dim is parametric.
// TODO Verify the unknown dim size in MLIR
dims.push_back(-1);
}
} else {
//TODO How to represent variable length
// TODO How to represent variable length
dims.push_back(-1);
}
}
if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) {
mlir::Type elementType = TypeConvert(input.type().tensor_type().elem_type());
mlir::Type elementType =
TypeConvert(input.type().tensor_type().elem_type());
llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
auto dataType = builder_.getTensorType(llvmdimsAR, elementType);
mlir::OperationState result(UnknownLoc(), "sgir.input "+input_tensor_legalized_name);
mlir::OperationState result(
UnknownLoc(), "sgir.input " + input_tensor_legalized_name);
result.addTypes(dataType);
auto op = builder_.createOperation(result);
auto value = op->getResult(0);
sgir_symbols_.AddMapping(input_tensor_legalized_name, value);
} else {
//TODO Should not happen
// TODO Should not happen
}
}
@ -151,73 +150,71 @@ private:
inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item));
}
}
mlir::OperationState result(UnknownLoc(), "SGIR."+node.op_type());
mlir::OperationState result(UnknownLoc(), "SGIR." + node.op_type());
for (auto item : node.output()) {
result.addTypes(builder_.getTensorType(builder_.getF32Type()));
}
result.addOperands(inputs);
auto op = builder_.createOperation(result);
for (int i=0 ; i< node.output().size(); i++) {
for (int i = 0; i < node.output().size(); i++) {
auto r = builder_.createOperation(result)->getResult(i);
sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r);
}
//TODO more info from node: attributes
// TODO more info from node: attributes
}
void ImportOutputTensor(onnx::ValueInfoProto& output) {
if(sgir_symbols_.ContainKey(legalize_name(output.name()))) {
mlir::OperationState result(UnknownLoc(), "sgir.output "+output.name());
if (sgir_symbols_.ContainKey(legalize_name(output.name()))) {
mlir::OperationState result(UnknownLoc(), "sgir.output " + output.name());
result.addTypes(builder_.getTensorType(builder_.getF32Type()));
result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name()));
builder_.createOperation(result);
} else {
//TODO: Why not in the symbol table? something is wrong
// TODO: Why not in the symbol table? something is wrong
}
}
void ImportGraph(onnx::GraphProto graph) {
//create a function for the graph
//TODO:
// create a function for the graph
// TODO:
// * get name and type for the function.
// * maintain a list of the defined graph
llvm::SmallVector<mlir::Type, 4> ret_types;
llvm::SmallVector<mlir::Type, 4> arg_types;
auto func_type = builder_.getFunctionType(arg_types, ret_types);
auto llvmfunction = mlir::FuncOp::create(UnknownLoc(),
graph.name(), func_type, /* attrs = */ {});
auto &entryBlock = *llvmfunction.addEntryBlock();
auto llvmfunction = mlir::FuncOp::create(
UnknownLoc(), graph.name(), func_type, /* attrs = */ {});
auto& entryBlock = *llvmfunction.addEntryBlock();
builder_.setInsertionPointToStart(&entryBlock);
module_.push_back(llvmfunction);
//TODO: import the initializer
// TODO: import the initializer
//
//import the input tensors
// import the input tensors
for (auto input : graph.input()) {
ImportInputTensor(input);
}
//import nodes in the graph
// import nodes in the graph
auto node = graph.node();
for (auto item: node) {
for (auto item : node) {
ImportNode(item);
}
//import the output tensors
// import the output tensors
for (auto output : graph.output()) {
ImportOutputTensor(output);
}
}
} ; //SGIRGenImpl class
}; // SGIRGenImpl class
} //namespace
} //namespace onnf
} // namespace
} // namespace dlc
namespace onnf {
/*!
* Generate SGIR with MLIR for a onnx model
* @param model onnx model.
@ -232,5 +229,4 @@ mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) {
return module;
}
} //namespace onnf
} // namespace onnf

View File

@ -24,12 +24,11 @@ class OwningModuleRef;
} // namespace mlir
namespace onnf {
/*!
/*!
* Import an ONNX Model into SGIR
* @param model onnx model.
* @return MLIR::module generated for the ONNX model
*/
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
} //namespace onnf
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
} // namespace onnf