onnx-mlir/src/builder/frontend_dialect_transforme...

531 lines
19 KiB
C++

//===- frontend_dialect_transformer.cpp - MLIR Operations -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file transforms the input to available MLIR dialects that can represent
// the operations of the model. Models use the ONNX dialect and any other
// extension dialects that comprise the operations not supported or covered
// by the ONNX specification.
//
// A `frontend` placeholder dialect is used to encode operations that are not
// covered by any existing dialects.
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <map>
#include <numeric>
#include <regex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
// Using backported variant.
// bstd = backported standard library.
#include <mpark/variant.hpp>
namespace bstd = mpark;
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "frontend_dialect_transformer.hpp"
namespace onnf {
namespace {
/*!
 * Replace, in place, every non-overlapping occurrence of `from` in `str` with
 * `to`, scanning left to right.
 * @param str string to modify.
 * @param from substring to search for; an empty pattern makes this a no-op.
 * @param to replacement text.
 */
void replaceAll(std::string &str, const std::string &from,
                const std::string &to) {
  // An empty pattern would match at every position; treat it as "nothing to do".
  if (from.empty())
    return;
  std::string::size_type cursor = str.find(from);
  while (cursor != std::string::npos) {
    str.replace(cursor, from.length(), to);
    // Resume the search after the inserted text so a `to` that itself contains
    // `from` (e.g. replacing "x" with "yx") cannot loop forever.
    cursor = str.find(from, cursor + to.length());
  }
}
/*!
 * Turn an ONNX tensor name into a legal identifier-like string:
 * '/' and '-' become '_', ':' becomes "_colon_", and a name starting with a
 * digit gets an 'n' prefix. Applying it to an already-legalized name returns
 * the name unchanged.
 * @param name raw ONNX tensor name (taken by value and edited in place).
 * @return the legalized name.
 */
std::string legalize_name(std::string name) {
  std::replace(name.begin(), name.end(), '/', '_');
  std::replace(name.begin(), name.end(), '-', '_');
  replaceAll(name, ":", "_colon_");
  // If tensor name starts with a number, prepend n to make it a legal c++
  // identifier. std::isdigit has undefined behaviour for negative char values
  // (e.g. bytes of a UTF-8 encoded name where char is signed), so the
  // character is converted to unsigned char first.
  if (!name.empty() && std::isdigit(static_cast<unsigned char>(name.front())))
    name.insert(0, 1, 'n');
  return name;
}
struct OnnxOnnfSymbolMapping {
/*!
* Get MLIR tensor by onnx tensor name.
* @param name onnx tensor name.
* @return onnf tensor corresponding to `name`.
*/
mlir::Value GetTensorByOnnxName(const std::string &name) {
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
onnx_name2onnf_tensor.end() &&
"Tensor not found");
return onnx_name2onnf_tensor.at(legalize_name(name));
}
/*!
* Add a new mapping from onnx tensor name to MLIR symbol.
* @param name onnx tensor name.
* @param tensor MLIR Value pointer.
*/
void AddMapping(const std::string &name, mlir::Value tensor) {
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
"Tensor already exists.");
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
}
bool ContainKey(std::string name) {
return onnx_name2onnf_tensor.count(name) != 0;
}
private:
/*!
* mapping from onnx tensor names to MLIR tensor.
*/
std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
};
/*!
 * Builds an MLIR module from an ONNX model. The model's graph becomes a
 * function in the module; each ONNX node becomes an ONNX-dialect operation or,
 * via ImportNodeGeneric, a generic `frontend.<op_type>` operation.
 */
class FrontendGenImpl {
public:
  explicit FrontendGenImpl(mlir::MLIRContext &context)
      : context_(context), builder_(&context) {
    module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  }

  /*!
   * Import the graph of an ONNX model into this generator's module.
   * @param model parsed ONNX model; taken by const reference so that the
   *        (possibly large) protobuf is not copied.
   * @return the module holding the imported graph.
   */
  mlir::ModuleOp ImportONNXModel(const onnx::ModelProto &model) {
    ImportGraph(model.graph());
    return module_;
  }

private:
  mlir::MLIRContext &context_;
  mlir::ModuleOp module_;
  mlir::OpBuilder builder_;
  // Placeholder passed for omitted optional operands; created in ImportGraph
  // as a ConstantOp holding a unit attribute.
  mlir::Value none_;
  // mapping between string name and symbol
  OnnxOnnfSymbolMapping frontend_symbols_;

  mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }

  // Convert type to MLIR type.
  // A complete list of types can be found in:
  // <onnf-build-folder>/third_party/onnx/onnx/onnx.pb.h
  // Unsigned ONNX integer types map to signless MLIR integers of the same
  // width. Unsupported types assert and return a null type.
  mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) {
    switch (onnxType) {
    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
      return builder_.getF16Type();
    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
      return builder_.getF32Type();
    case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
      return builder_.getF64Type();
    case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
    case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
      return builder_.getIntegerType(8);
    case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
    case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
      return builder_.getIntegerType(16);
    case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
    case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
      return builder_.getIntegerType(32);
    case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
    case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
      return builder_.getIntegerType(64);
    case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
      return builder_.getI1Type();
    case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
    case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
    case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
    case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
    default:
      // `default` also catches enum values not listed above (e.g. BFLOAT16),
      // which previously fell off the end of this non-void function.
      assert(false && "Unsupported data type encountered.");
      return nullptr;
    }
  }

  /*!
   * Import an onnx input tensor type by determining and recording its type
   * in a list of input tensor mlir types.
   * @param input onnx input tensor ValueInfoProto.
   * @param arg_types list of mlir types representing types of graph input.
   */
  void ImportInputTensorType(const onnx::ValueInfoProto &input,
                             llvm::SmallVector<mlir::Type, 4> &arg_types) {
    const auto &tensor_type = input.type().tensor_type();
    std::vector<int64_t> dims;
    for (const auto &dim : tensor_type.shape().dim()) {
      // dim_value is a 64-bit field; keep its full width instead of
      // truncating to int. Zero (unset, e.g. a symbolic dim_param) and
      // negative values are recorded as -1 (unknown).
      // TODO How to represent variable length
      const int64_t dim_value = dim.dim_value();
      dims.push_back(dim_value > 0 ? dim_value : -1);
    }
    // static_cast keeps this valid whether elem_type() is declared as the
    // enum or as a plain int32 in the generated protobuf code.
    mlir::Type elementType = convertONNXTypeToMLIRType(
        static_cast<onnx::TensorProto_DataType>(tensor_type.elem_type()));
    arg_types.emplace_back(mlir::RankedTensorType::get(dims, elementType));
  }

  /*!
   * Import a input tensor symbol by recording a new entry in frontend_symbols_
   * recording the mapping between legalized onnx tensor name and mlir::Value
   * for further lookup in computation node importing.
   * @param input onnx input tensor ValueInfoProto.
   * @param symbol mlir input argument.
   */
  void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
                               mlir::Value symbol) {
    auto input_tensor_legalized_name = legalize_name(input.name());
    assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
           "Found duplicate legalized input tensor names.");
    frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
  }

  using AttrValueType =
      bstd::variant<int64_t, std::vector<int64_t>, float, std::vector<float>,
                    std::string, std::vector<std::string>>;

  /*!
   * Visitor turning an AttrValueType alternative into an MLIR named attribute
   * called `_name`.
   */
  struct ONNXAttrVisitor {
    ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder)
        : _builder(builder), _name(std::move(name)) {}

    // Op builder.
    mlir::OpBuilder &_builder;

    // Name of the attribute being inspected.
    std::string _name;

    mlir::NamedAttribute operator()(int64_t const &r) {
      auto val = _builder.getI64IntegerAttr(r);
      return _builder.getNamedAttr(_name, val);
    }

    mlir::NamedAttribute operator()(std::vector<int64_t> const &ints) {
      auto val = _builder.getI64ArrayAttr(ints);
      return _builder.getNamedAttr(_name, val);
    }

    mlir::NamedAttribute operator()(float const &r) {
      auto val = _builder.getF32FloatAttr(r);
      return _builder.getNamedAttr(_name, val);
    }

    mlir::NamedAttribute operator()(std::vector<float> const &floats) {
      auto val = _builder.getF32ArrayAttr(floats);
      return _builder.getNamedAttr(_name, val);
    }

    mlir::NamedAttribute operator()(std::string const &s) {
      auto val = _builder.getStringAttr(s);
      return _builder.getNamedAttr(_name, val);
    }

    mlir::NamedAttribute operator()(std::vector<std::string> const &) {
      assert(false && "type of attribute value is not implemented");
      auto val = _builder.getI32IntegerAttr(1);
      return _builder.getNamedAttr(_name, val);
    }
  };

  mlir::NamedAttribute convertNameValuePairToNamedAttribute(
      const std::pair<std::string, AttrValueType> &nameAndVal) {
    auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_);
    return bstd::visit(visitor, nameAndVal.second);
  }

  /*!
   * Extract the name and value of a FLOAT, INT, STRING, FLOATS or INTS
   * attribute. Any other attribute kind asserts and, when assertions are
   * disabled, aborts (previously control fell off the end of this non-void
   * function, which is undefined behaviour).
   */
  static std::pair<std::string, AttrValueType>
  convertAttributeProtoToNameValuePair(const onnx::AttributeProto &attr) {
    switch (attr.type()) {
    case onnx::AttributeProto::FLOAT:
      return std::make_pair(attr.name(), AttrValueType(attr.f()));
    case onnx::AttributeProto::INT:
      return std::make_pair(attr.name(), AttrValueType(attr.i()));
    case onnx::AttributeProto::STRING:
      return std::make_pair(attr.name(), AttrValueType(attr.s()));
    case onnx::AttributeProto::FLOATS:
      return std::make_pair(
          attr.name(), AttrValueType(std::vector<float>(attr.floats().begin(),
                                                        attr.floats().end())));
    case onnx::AttributeProto::INTS:
      return std::make_pair(
          attr.name(), AttrValueType(std::vector<int64_t>(attr.ints().begin(),
                                                          attr.ints().end())));
    default:
      break;
    }
    assert(false && "datatype for attribute is not implemented");
    std::abort();
  }

  /*!
   * Convert every attribute of `node` into an MLIR named attribute.
   */
  std::vector<mlir::NamedAttribute>
  ImportNodeAttributes(const onnx::NodeProto &node) {
    std::vector<mlir::NamedAttribute> attributes;
    attributes.reserve(node.attribute_size());
    // Iterate by const reference: the AttributeProtos were copied before.
    for (const auto &attr : node.attribute())
      attributes.push_back(convertNameValuePairToNamedAttribute(
          convertAttributeProtoToNameValuePair(attr)));
    return attributes;
  }

  /*!
   * Register `value` as the tensor produced for ONNX output `name`. ONNX uses
   * an empty name for an omitted optional output; such names are not
   * registered, so that several nodes omitting outputs do not collide and a
   * later empty input name is never resolved to one of them.
   */
  void MapNodeOutput(const std::string &name, mlir::Value value) {
    if (!name.empty())
      frontend_symbols_.AddMapping(legalize_name(name), value);
  }

  /*!
   * Import a node as a generic `frontend.<op_type>` operation. Inputs that
   * are not known symbols are skipped. Every result is typed as an unranked
   * f32 tensor.
   */
  void ImportNodeGeneric(const onnx::NodeProto &node) {
    std::vector<mlir::Value> inputs;
    for (const auto &item : node.input())
      if (frontend_symbols_.ContainKey(legalize_name(item)))
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));

    mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
    const mlir::Type outputType =
        mlir::UnrankedTensorType::get(builder_.getF32Type());
    for (int i = 0; i < node.output_size(); ++i)
      result.addTypes(outputType);
    result.addOperands(inputs);
    auto op = builder_.createOperation(result);
    for (int i = 0; i < node.output_size(); ++i)
      MapNodeOutput(node.output(i), op->getResult(i));
  }

  /*!
   * Import a node as an operation of type T.
   * @param node the ONNX node.
   * @param expectedNumOperands operand count the op expects, or -1 for a
   *        variadic op. For non-variadic ops, omitted optional inputs (empty
   *        names) keep their slot as `none_`, and missing trailing operands
   *        are padded with `none_`.
   * @param expectedNumResults currently unused; kept for the callers in the
   *        generated build table.
   */
  template <typename T>
  void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
                      int expectedNumResults = -1) {
    const bool variadicIn = expectedNumOperands == -1;

    std::vector<mlir::Value> inputs;
    for (const auto &item : node.input()) {
      if (frontend_symbols_.ContainKey(legalize_name(item)))
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
      else if (item.empty() && !variadicIn)
        // An empty name marks an omitted optional input. Keep its slot so the
        // operands after it stay at their schema positions.
        inputs.push_back(none_);
    }
    if (!variadicIn)
      while (inputs.size() < static_cast<size_t>(expectedNumOperands))
        inputs.push_back(none_);

    const mlir::Type outputType =
        mlir::UnrankedTensorType::get(builder_.getF32Type());
    std::vector<mlir::Type> outputTypes(node.output_size(), outputType);

    auto attributes = ImportNodeAttributes(node);

    auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
    for (int i = 0; i < node.output_size(); ++i)
      MapNodeOutput(node.output(i), *(op.getODSResults(i).begin()));
  }

  /*!
   * Special handle for Conv operations.
   * c++ does not allow template specialization inside a class scope
   * a specialized function is used
   */
  void ImportNodeConv(const onnx::NodeProto &node, int nIn, int nOut) {
    // Conv has attribute dilations, kernel_shape, pads, the default value of
    // which is determined by the shape of first argument. However, since the
    // shape is unknown now, these attributes can be not generated auto
    // dilations_attr = get_attr_ints(node, "dilations",
    //    std::vector<int>(inputs[0]->getType().cast<RankedTensorType>.getDims()-2,
    //    1));
    // attributes.push_back(dilations_attr)
    // similar situation for pads, strides in AveragePool
    // axes of ReduceSum,  pads, strides, dilations and kernel_shape of MaxPool
    // TODO: fix this after type inference
    const int nOps = node.input().size();
    if (nOps == 2)
      buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
    else
      buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
  }

  /*!
   * Special handle for MaxPool operations.
   */
  void ImportNodeMaxPool(const onnx::NodeProto &node, int nIn, int nOut) {
    const int nOuts = node.output().size();
    if (nOuts == 1)
      buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
    else
      buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
  }

  /*!
   * Special handle for BatchNormalization operations.
   */
  void ImportNodeBatchNormalization(const onnx::NodeProto &node, int nIn,
                                    int nOut) {
    const int nOuts = node.output().size();
    if (nOuts == 1) {
      // Test mode with one output.
      buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
    } else {
      // Training mode with four trailing optional outputs. Not handled yet.
      buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
    }
  }

  /*!
   * Special handle for Pad operations.
   */
  void ImportNodePad(const onnx::NodeProto &node, int nIn, int nOut) {
    const int nOps = node.input().size();
    if (nOps == 2)
      buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
    else
      buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
  }

  /*!
   * Import one ONNX node. The per-op handling comes from a generated table
   * that is textually included below.
   */
  void ImportNode(const onnx::NodeProto &node) {
    llvm::StringRef opName = node.op_type();

    // The following code is generated by gen_doc.py;
    // refer to dialect/onnx/onnx.td for details.
    // When the inputs or outputs of the op do not match the specification,
    // the generic operator is used. One known reason is optional inputs.
#include "src/builder/op_build_table.inc"
  }

  /*!
   * Import output tensor, by doing the following:
   * - Add the type of this output tensor to a list of tensor
   *   types representing return types of this graph function.
   * - Add this output tensor to the list of mlir::Value
   *   to be returned by the function representing computation graph.
   * @param output onnx output tensor ValueInfoProto.
   * @param ret_types a vector of tensor types representing graph's
   *   output tensor types.
   * @param ret_vals a vector of mlir Value  representing graph's
   *   output tensor.
   */
  void ImportOutputTensor(const onnx::ValueInfoProto &output,
                          llvm::SmallVectorImpl<mlir::Type> &ret_types,
                          llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
    auto output_tensor_legalized_name = legalize_name(output.name());
    assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
           "Output tensor not found");

    auto tensor_val =
        frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
    ret_types.emplace_back(tensor_val.getType());
    ret_vals.push_back(tensor_val);
  }

  /*!
   * Import `graph` as a function called `name`, plus an ONNXEntryPointOp
   * referring to it. Graph inputs become block arguments. Nodes are imported
   * in order, and graph outputs are returned by a ReturnOp.
   */
  void ImportGraph(const onnx::GraphProto &graph,
                   const std::string &name = "main_graph") {
    // TODO:
    //   * get name and type for the function.
    //   * maintain a list of the defined graph
    llvm::SmallVector<mlir::Type, 4> arg_types;

    // Import the input tensor types.
    for (const auto &input : graph.input())
      ImportInputTensorType(input, arg_types);

    // TODO: import the initializer
    // Result types are only known after the nodes are imported, so the
    // function starts with no results and its type is updated at the end.
    auto funcType = builder_.getFunctionType(arg_types, {});
    auto mainFunc =
        mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
    auto entryPoint = mlir::ONNXEntryPointOp::create(
        UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(),
        /*numOutputs=*/graph.output().size());

    auto &entryBlock = *mainFunc.addEntryBlock();
    builder_.setInsertionPointToStart(&entryBlock);

    module_.push_back(mainFunc);
    module_.push_back(entryPoint);

    for (auto it : llvm::zip(graph.input(), entryBlock.getArguments()))
      ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));

    // Create a NoneTyped constant.
    none_ =
        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());

    // Import nodes in the graph.
    for (const auto &item : graph.node())
      ImportNode(item);

    llvm::SmallVector<mlir::Type, 4> ret_types;
    llvm::SmallVector<mlir::Value, 4> ret_vals;
    // Import the output tensors
    for (const auto &output : graph.output())
      ImportOutputTensor(output, ret_types, ret_vals);

    // Create a return operation to return all ONNX output tensors.
    builder_.create<mlir::ReturnOp>(UnknownLoc(), ret_vals);
    // Update main function signature to reflect types of newly imported
    // output tensors.
    funcType = builder_.getFunctionType(arg_types, ret_types);
    mainFunc.setType(funcType);
  }
}; // FrontendGenImpl class
} // namespace
} // namespace onnf
namespace onnf {
/*!
 * Import an in-memory ONNX model into a new MLIR module.
 * @param model parsed ONNX model.
 * @return owning reference to the module holding the imported graph.
 */
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
  // The module and everything it contains are created from this context, so
  // the context must outlive the returned module. A context local to this
  // call was destroyed on return and left the caller holding a dangling
  // module. A function-local static lives until program exit and is shared
  // by all calls.
  static mlir::MLIRContext context;
  FrontendGenImpl myONNXGen(context);
  return myONNXGen.ImportONNXModel(model);
}
/*!
 * Read a binary-serialized ONNX model from disk and import it.
 * @param model_fname path to the .onnx file.
 * @param context MLIR context the module is created in; it must outlive
 *        `module`.
 * @param module receives the module holding the imported graph.
 * Failing to open or parse the file trips an assertion. With assertions
 * disabled, import proceeds with whatever was parsed, as before.
 */
void ImportFrontendModelFile(std::string model_fname,
                             mlir::MLIRContext &context,
                             mlir::OwningModuleRef &module) {
  onnx::ModelProto model;
  std::ifstream input(model_fname, std::ios::binary);
  assert(input.is_open() && "Could not open the ONNX model file.");

  const bool parse_success = model.ParseFromIstream(&input);
  assert(parse_success && "Onnx Model Parsing Failed.");
  (void)parse_success; // Only read by the assert; avoids NDEBUG warnings.

  FrontendGenImpl myONNXGen(context);
  module = myONNXGen.ImportONNXModel(model);
}
} // namespace onnf