258 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			258 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			C++
		
	
	
	
| //===- frontend_dialect_transformer.cpp - MLIR Operations -----------------===//
 | |
| //
 | |
| // Copyright 2019 The IBM Research Authors. 
 | |
| //
 | |
| // =============================================================================
 | |
| //
 | |
| // This file transforms the input to available MLIR dialects that can represent
 | |
| // the operations of the model. Models use the ONNX dialect and any other
 | |
| // extension dialects that comprise the operations not supported or covered
 | |
| // by the ONNX specification.
 | |
| //
 | |
| // A `frontend` placeholder dialect is used to encode operations that are not
 | |
| // covered by any existing dialects.
 | |
| //
 | |
| //===----------------------------------------------------------------------===//
 | |
| 
 | |
#include <algorithm>
#include <cctype>
#include <fstream>
#include <map>
#include <numeric>
#include <regex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
 | |
| 
 | |
| #include "mlir/Analysis/Verifier.h"
 | |
| #include "mlir/Dialect/StandardOps/Ops.h"
 | |
| #include "mlir/IR/Attributes.h"
 | |
| #include "mlir/IR/Builders.h"
 | |
| #include "mlir/IR/Function.h"
 | |
| #include "mlir/IR/Location.h"
 | |
| #include "mlir/IR/MLIRContext.h"
 | |
| #include "mlir/IR/Module.h"
 | |
| #include "mlir/IR/StandardTypes.h"
 | |
| #include "mlir/IR/Types.h"
 | |
| 
 | |
| #include "llvm/ADT/STLExtras.h"
 | |
| #include "llvm/ADT/ScopedHashTable.h"
 | |
| #include "llvm/Support/raw_ostream.h"
 | |
| 
 | |
| #include "frontend_dialect_transformer.hpp"
 | |
| #include "src/compiler/dialect/onnx/onnx_ops.hpp"
 | |
| 
 | |
| namespace onnf {
 | |
| namespace {
 | |
| 
 | |
/// Replaces, in place, every occurrence of `from` inside `str` with `to`.
///
/// The scan resumes just past each inserted `to`, so a replacement that itself
/// contains `from` (e.g. replacing "x" with "yx") is never re-expanded.
/// An empty `from` leaves `str` untouched.
void replaceAll(
    std::string& str, const std::string& from, const std::string& to) {
  if (!from.empty()) {
    for (size_t pos = str.find(from); pos != std::string::npos;
         pos = str.find(from, pos + to.length())) {
      str.replace(pos, from.length(), to);
    }
  }
}
 | |
| 
 | |
| std::string legalize_name(std::string name) {
 | |
|   std::replace(name.begin(), name.end(), '/', '_');
 | |
|   std::replace(name.begin(), name.end(), '-', '_');
 | |
|   replaceAll(name, ":", "_colon_");
 | |
|   // If tensor name starts with a number, prepend n to make it a legal c++
 | |
|   // identifier.
 | |
|   if (name.size() > 0 && isdigit(name.at(0)))
 | |
|     name.insert(0, 1, 'n');
 | |
|   return name;
 | |
| }
 | |
| 
 | |
| struct OnnxOnnfSymbolMapping {
 | |
|   /*!
 | |
|    *  Get MLIR tensor by onnx tensor name.
 | |
|    *  @param name onnx tensor name.
 | |
|    *  @return onnf tensor corresponding to `name`.
 | |
|    */
 | |
|   mlir::Value* GetTensorByOnnxName(std::string name) {
 | |
|     return onnx_name2onnf_tensor.at(legalize_name(name));
 | |
|   }
 | |
| 
 | |
|   /*!
 | |
|    *  Add a new mapping from onnx tensor name to MLIR symbol.
 | |
|    *  @param name onnx tensor name.
 | |
|    *  @param tensor MLIR Value* pointer.
 | |
|    */
 | |
|   void AddMapping(std::string name, mlir::Value* tensor) {
 | |
|     onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
 | |
|   }
 | |
| 
 | |
|   bool ContainKey(std::string name) {
 | |
|     return onnx_name2onnf_tensor.count(name) != 0;
 | |
|   }
 | |
| 
 | |
|  private:
 | |
|   /*!
 | |
|    *  mapping from onnx tensor names to MLIR tensor.
 | |
|    */
 | |
|   std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
 | |
| };
 | |
| 
 | |
| class FrontendGenImpl {
 | |
|  public:
 | |
|   FrontendGenImpl(mlir::MLIRContext& context)
 | |
|       : context_(context), builder_(&context) {
 | |
|     module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
 | |
|   }
 | |
| 
 | |
|   mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) {
 | |
|     ImportGraph(model.graph());
 | |
|     return module_;
 | |
|   }
 | |
| 
 | |
|  private:
 | |
|   mlir::MLIRContext& context_;
 | |
|   mlir::ModuleOp module_;
 | |
|   mlir::OpBuilder builder_;
 | |
|   // mapping between string name and symbol
 | |
|   OnnxOnnfSymbolMapping frontend_symbols_;
 | |
| 
 | |
|   mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
 | |
| 
 | |
|   mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
 | |
|     return builder_.getF32Type();
 | |
|   }
 | |
| 
 | |
|   void ImportInputTensor(onnx::ValueInfoProto& input) {
 | |
|     std::vector<int64_t> dims;
 | |
|     auto shape_proto = input.type().tensor_type().shape();
 | |
|     auto input_tensor_legalized_name = legalize_name(input.name());
 | |
|     for (int i = 0; i < shape_proto.dim_size(); i++) {
 | |
|       if (shape_proto.dim()[i].dim_value()) {
 | |
|         int dim_numeric_size = shape_proto.dim()[i].dim_value();
 | |
|         if (dim_numeric_size > 0) {
 | |
|           dims.push_back(dim_numeric_size);
 | |
|         } else {  // If dim_value < 0, then dim is parametric.
 | |
|                   // TODO Verify the unknown dim size in MLIR
 | |
|           dims.push_back(-1);
 | |
|         }
 | |
|       } else {
 | |
|         // TODO How to represent variable length
 | |
|         dims.push_back(-1);
 | |
|       }
 | |
|     }
 | |
|     if (!frontend_symbols_.ContainKey(input_tensor_legalized_name)) {
 | |
|       mlir::Type elementType =
 | |
|           TypeConvert(input.type().tensor_type().elem_type());
 | |
|       llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
 | |
|       auto dataType = mlir::RankedTensorType::get(llvmdimsAR, elementType);
 | |
|       mlir::OperationState result(
 | |
|           UnknownLoc(), "frontend.input " + input_tensor_legalized_name);
 | |
|       result.addTypes(dataType);
 | |
|       auto op = builder_.createOperation(result);
 | |
|       auto value = op->getResult(0);
 | |
|       frontend_symbols_.AddMapping(input_tensor_legalized_name, value);
 | |
|     } else {
 | |
|       // TODO  Should not happen
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   void ImportNode(onnx::NodeProto node) {
 | |
|     std::vector<mlir::Value*> inputs;
 | |
|     for (auto item : node.input()) {
 | |
|       if (frontend_symbols_.ContainKey(legalize_name(item))) {
 | |
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     // Handle ONNX Add Operation by using its representation in the
 | |
|     // ONNX Dialect.
 | |
|     llvm::StringRef OpName = node.op_type();
 | |
|     if (OpName == "Add") {
 | |
|       auto op =
 | |
|           builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
 | |
|       frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult());
 | |
|       return;
 | |
|     }
 | |
| 
 | |
|     // Old way of doing things.
 | |
|     mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
 | |
|     for (auto item : node.output()) {
 | |
|       result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
 | |
|     }
 | |
|     result.addOperands(inputs);
 | |
|     auto op = builder_.createOperation(result);
 | |
|     for (int i = 0; i < node.output().size(); i++) {
 | |
|       auto r = op->getResult(i);
 | |
|       frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
 | |
|     }
 | |
| 
 | |
|     // TODO more info from node: attributes
 | |
|   }
 | |
| 
 | |
|   void ImportOutputTensor(onnx::ValueInfoProto& output) {
 | |
|     if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
 | |
|       mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name());
 | |
|       result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
 | |
|       result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
 | |
|       builder_.createOperation(result);
 | |
|     } else {
 | |
|       // TODO: Why not in the symbol table? something is wrong
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   void ImportGraph(onnx::GraphProto graph) {
 | |
|     // create a function for the graph
 | |
|     // TODO:
 | |
|     //  * get name and type for the function.
 | |
|     //  * maintain a list of the defined graph
 | |
|     llvm::SmallVector<mlir::Type, 4> ret_types;
 | |
|     llvm::SmallVector<mlir::Type, 4> arg_types;
 | |
|     auto func_type = builder_.getFunctionType(arg_types, ret_types);
 | |
|     auto llvmfunction = mlir::FuncOp::create(
 | |
|         UnknownLoc(), graph.name(), func_type, /* attrs = */ {});
 | |
|     auto& entryBlock = *llvmfunction.addEntryBlock();
 | |
|     builder_.setInsertionPointToStart(&entryBlock);
 | |
|     module_.push_back(llvmfunction);
 | |
| 
 | |
|     // TODO: import the initializer
 | |
|     //
 | |
| 
 | |
|     // import the input tensors
 | |
|     for (auto input : graph.input()) {
 | |
|       ImportInputTensor(input);
 | |
|     }
 | |
| 
 | |
|     // import nodes in the graph
 | |
|     auto node = graph.node();
 | |
|     for (auto item : node) {
 | |
|       ImportNode(item);
 | |
|     }
 | |
| 
 | |
|     // import the output tensors
 | |
|     for (auto output : graph.output()) {
 | |
|       ImportOutputTensor(output);
 | |
|     }
 | |
|   }
 | |
| 
 | |
| };  // FrontendGenImpl class
 | |
| 
 | |
| }  // namespace
 | |
| }  // namespace onnf
 | |
| 
 | |
| namespace onnf {
 | |
| 
 | |
| mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
 | |
|   mlir::MLIRContext context;
 | |
|   FrontendGenImpl myONNXGen(context);
 | |
|   auto module = myONNXGen.ImportONNXModel(model);
 | |
|   module.dump();
 | |
| 
 | |
|   return module;
 | |
| }
 | |
| 
 | |
| mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) {
 | |
|   onnx::ModelProto model;
 | |
|   std::fstream input(model_fname, std::ios::in | std::ios::binary);
 | |
| 
 | |
|   auto parse_success = model.ParseFromIstream(&input);
 | |
| 
 | |
|   return ImportFrontendModel(model);
 | |
| }
 | |
| }  // namespace onnf
 |