//===--------- FrontendDialectTransformer.cpp - MLIR Operations -----------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file transforms the input to available MLIR dialects that can represent
// the operations of the model. Models use the ONNX dialect and any other
// extension dialects that comprise the operations not supported or covered
// by the ONNX specification.
//
// A `frontend` placeholder dialect is used to encode operations that are not
// covered by any existing dialects.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

// Using backported variant.
// bstd = backported standard library.
#include <mpark/variant.hpp>
namespace bstd = mpark;

#include "FrontendDialectTransformer.hpp"

namespace onnx_mlir {

namespace {

/*!
 * The list of tensors initialized by the ONNX model.
 */
InitializedTensorMapping initializedTensors;

class FrontendGenImpl {
public:
  FrontendGenImpl(mlir::MLIRContext &context)
      : context_(context), builder_(&context) {
    module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  }

  mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) {
    ImportGraph(model.graph());
    return module_;
  }

private:
  mlir::MLIRContext &context_;
  mlir::ModuleOp module_;
  mlir::OpBuilder builder_;
  mlir::Value none_;
  // Mapping between legalized onnx tensor names and mlir symbol values.
  OnnxMlirSymbolMapping frontend_symbols_;

  mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }

  // Convert type to MLIR type.
  // A complete list of types can be found in:
  // <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
  mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) {
    switch (onnxType) {
    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
      return builder_.getF16Type();
    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
      return builder_.getF32Type();
    case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
      return builder_.getF64Type();
    case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
    case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
      return builder_.getIntegerType(/*width=*/8);
    case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
    case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
      return builder_.getIntegerType(/*width=*/16);
    case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
    case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
      return builder_.getIntegerType(/*width=*/32);
    case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
    case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
      return builder_.getIntegerType(/*width=*/64);
    case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
      return builder_.getI1Type();

    case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
    case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
    case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
    case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
    default:
      assert(false && "Unsupported data type encountered.");
      return nullptr;
    }
  }
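
  // Note: builder_.getIntegerType(width) produces a signless integer type,
  // so signed and unsigned onnx types of the same width (e.g. INT32 and
  // UINT32) intentionally map to the same MLIR type.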

  /*!
   * Import an onnx input tensor type by determining its corresponding
   * MLIR tensor type.
   * @param input onnx input tensor ValueInfoProto.
   * @return mlir tensor type of the input tensor.
   */
  mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) {
    std::vector<int64_t> dims;
    auto shape_proto = input.type().tensor_type().shape();
    for (int i = 0; i < shape_proto.dim_size(); i++) {
      if (shape_proto.dim()[i].dim_value()) {
        int dim_numeric_size = shape_proto.dim()[i].dim_value();
        assert(dim_numeric_size != 0 &&
               "Parsed an input tensor with a dimension size of zero");
        if (dim_numeric_size > 0) {
          dims.push_back(dim_numeric_size);
        } else { // If dim_value < 0, then dim is parametric.
          // TODO: Verify the unknown dim size in MLIR.
          dims.push_back(-1);
        }
      } else {
        // TODO: How to represent variable-length dimensions?
        dims.push_back(-1);
      }
    }

    auto elementOnnxType =
        (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
    mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
    llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
    return mlir::RankedTensorType::get(tensor_dims, elementType);
  }
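
  // For illustration: an onnx input typed tensor(float) with shape
  // [batch, 3, 224, 224], where `batch` is a symbolic dimension, imports as
  // tensor<?x3x224x224xf32> (a -1 dimension prints as `?` in MLIR).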

  /*!
   * Import an input tensor symbol by recording a new entry in
   * frontend_symbols_ that maps the legalized onnx tensor name to its
   * mlir::Value for further lookup when importing computation nodes.
   * @param input onnx input tensor ValueInfoProto.
   * @param symbol mlir input argument.
   */
  void ImportInputTensorSymbol(
      const onnx::ValueInfoProto &input, mlir::Value symbol) {
    auto input_tensor_legalized_name = legalize_name(input.name());
    assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
           "Found duplicate legalized input tensor names.");
    frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
  }

  mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
      onnx::AttributeProto &attr) {
    mlir::Attribute mlirAttr;
    switch (attr.type()) {
    case onnx::AttributeProto::FLOAT:
      mlirAttr = builder_.getF32FloatAttr(attr.f());
      break;
    case onnx::AttributeProto::INT:
      mlirAttr = builder_.getI64IntegerAttr(attr.i());
      break;
    case onnx::AttributeProto::STRING:
      mlirAttr = builder_.getStringAttr(attr.s());
      break;
    case onnx::AttributeProto::FLOATS:
      mlirAttr = builder_.getF32ArrayAttr(
          llvm::makeArrayRef(attr.floats().begin(), attr.floats().end()));
      break;
    case onnx::AttributeProto::INTS:
      mlirAttr = builder_.getI64ArrayAttr(
          llvm::makeArrayRef(attr.ints().begin(), attr.ints().end()));
      break;
    case onnx::AttributeProto::TENSOR:
      mlirAttr = onnxTensorProtoToDenseElmAttr(builder_, attr.t());
      break;
    default:
      llvm_unreachable("datatype for attribute is not implemented");
      break;
    }
    return builder_.getNamedAttr(attr.name(), mlirAttr);
  }
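
  // For illustration: an onnx FLOAT attribute `alpha = 0.5` imports as the
  // MLIR named attribute `alpha = 0.5 : f32`, and an INTS attribute
  // `kernel_shape = [3, 3]` imports as an i64 array attribute.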

  std::vector<mlir::NamedAttribute> ImportNodeAttributes(
      const onnx::NodeProto &node) {
    std::vector<mlir::NamedAttribute> attributes;
    for (int i = 0; i < node.attribute_size(); ++i) {
      auto attr = node.attribute(i);
      attributes.push_back(convertOnnxAttributeProtoToMlirNamedAttribute(attr));
    }
    return attributes;
  }
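
  /*!
   * Fallback import for nodes that have no matching ONNX dialect operation:
   * emit an opaque `frontend.<op_type>` operation whose results are typed as
   * unranked f32 tensors.
   */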
  void ImportNodeGeneric(const onnx::NodeProto &node) {
    std::vector<mlir::Value> inputs;
    for (const auto &item : node.input()) {
      if (frontend_symbols_.ContainKey(legalize_name(item))) {
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
      }
    }
    mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
    for (auto item : node.output()) {
      result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
    }
    result.addOperands(inputs);
    auto op = builder_.createOperation(result);
    for (int i = 0; i < node.output().size(); i++) {
      auto r = op->getResult(i);
      frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
    }
  }

#define MAX_TYPE 20
  // itblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32',
  //                  'F64', 'Complex<F32>', 'Complex<F64>')
  mlir::Type buildTypeFromIndex(int index) {
    switch (index) {
    case 0:
      return builder_.getI1Type();
    case 1:
      return builder_.getIntegerType(8);
    case 2:
      return builder_.getIntegerType(16);
    case 3:
      return builder_.getIntegerType(32);
    case 4:
      return builder_.getIntegerType(64);
    case 5:
      return builder_.getBF16Type();
    case 6:
      return builder_.getF16Type();
    case 7:
      return builder_.getF32Type();
    case 8:
      return builder_.getF64Type();
    case 9: {
      std::vector<mlir::Type> typeTuple;
      typeTuple.reserve(2);
      typeTuple.push_back(builder_.getF32Type());
      typeTuple.push_back(builder_.getF32Type());
      return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple));
    }
    case 10: {
      std::vector<mlir::Type> typeTuple;
      typeTuple.reserve(2);
      typeTuple.push_back(builder_.getF64Type());
      typeTuple.push_back(builder_.getF64Type());
      return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple));
    }
    default:
      assert(false && "Unsupported type index encountered.");
      return nullptr;
    }
  }
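
  // Note on the values of T::getTypeMap() used below (convention assumed
  // from the generated builder tables): an entry in [0, MAX_TYPE) selects
  // the itblgen type with that index via buildTypeFromIndex; an entry >=
  // MAX_TYPE means "same element type as operand (entry - MAX_TYPE)"; and
  // -1 marks an output with no statically derivable type, imported as
  // NoneType.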

  template <typename T>
  void buildOutputAndOperation(const onnx::NodeProto &node,
      std::vector<mlir::Value> inputs, int expectedNumOperands,
      int expectedNumResults) {
    bool variadicIn = expectedNumOperands == -1;
    bool variadicOut = expectedNumResults == -1;

    // In ONNX, there are two ways to leave an optional input or output
    // unspecified: the first, available only for trailing inputs and outputs,
    // is to simply not provide that input; the second method is to use an
    // empty string in place of an input or output name.
    //
    // Here, we import optional inputs and outputs as NoneType.

    // Trailing optional inputs.
    if (!variadicIn)
      for (int i = inputs.size(); i < expectedNumOperands; i++)
        inputs.emplace_back(none_);

    std::vector<mlir::Type> outputTypes;

    // Use the type map to determine the data type of output.
    std::vector<int> outputMap = T::getTypeMap();
    for (int i = 0; i < node.output().size(); i++) {
      // Optional outputs using empty string.
      if (node.output()[i].empty()) {
        outputTypes.emplace_back(builder_.getNoneType());
      } else {
        if (i < (int)outputMap.size() && outputMap[i] >= MAX_TYPE) {
          // Mapping gives a connection with an input.
          mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType();
          if (inputType.isa<mlir::TensorType>()) {
            auto elementType =
                inputType.cast<mlir::TensorType>().getElementType();
            auto outType = mlir::UnrankedTensorType::get(elementType);
            outputTypes.emplace_back(outType);
          } else {
            outputTypes.push_back(inputType);
          }
        } else if (i < (int)outputMap.size() && outputMap[i] != -1) {
          // Mapping gives a direct type.
          auto elementType = buildTypeFromIndex(outputMap[i]);
          auto outType = mlir::UnrankedTensorType::get(elementType);
          outputTypes.emplace_back(outType);
        } else {
          outputTypes.emplace_back(builder_.getNoneType());
        }
      }
    }
    // Trailing optional outputs.
    if (!variadicOut)
      for (int i = node.output().size(); i < expectedNumResults; ++i)
        outputTypes.emplace_back(builder_.getNoneType());

    auto attributes = ImportNodeAttributes(node);

    // TODO: Handle optional inputs.
    auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
    for (int i = 0; i < node.output().size(); i++) {
      frontend_symbols_.AddMapping(
          legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
    }
  }
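
  /*!
   * Import a node by collecting its inputs, either materialized from
   * initializers or looked up among previously imported values, and then
   * delegating to buildOutputAndOperation for the concrete operation T.
   */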
  template <typename T>
  void buildOperation(const onnx::NodeProto &node) {
    std::vector<mlir::Value> inputs;
    int expectedNumOperands = T::getNumberOfOperands();
    int expectedNumResults = T::getNumberOfResults();
    for (const auto &item : node.input())
      if (initializedTensors.ContainKey(legalize_name(item))) {
        inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
            UnknownLoc(), builder_, legalize_name(item)));
      } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
      }

    buildOutputAndOperation<T>(
        node, inputs, expectedNumOperands, expectedNumResults);
  }
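
  /*!
   * Special handling for Reshape operations: the shape operand frequently
   * comes from an initializer, in which case it is materialized here as a
   * constant rather than treated as a graph input.
   */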
  void ImportNodeReshape(onnx::NodeProto node) {
    int expectedNumOperands = mlir::ONNXReshapeOp::getNumberOfOperands();
    int expectedNumResults = mlir::ONNXReshapeOp::getNumberOfResults();
    std::vector<mlir::Value> inputs;
    std::string item;
    for (int i = 0; i < node.input().size(); ++i) {
      item = node.input()[i];
      // For the second argument, check if there exists an initializer.
      if (initializedTensors.ContainKey(legalize_name(item))) {
        inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
            UnknownLoc(), builder_, legalize_name(item)));
      } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
      }
    }

    buildOutputAndOperation<mlir::ONNXReshapeOp>(
        node, inputs, expectedNumOperands, expectedNumResults);
  }

  /*!
   * Special handling for MaxPool operations: onnx MaxPool has an optional
   * second output (the indices tensor); when only one output is present,
   * the simpler ONNXMaxPoolSingleOutOp variant is used.
   */
  void ImportNodeMaxPool(onnx::NodeProto node) {
    int nOuts = node.output().size();
    if (nOuts == 1) {
      buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node);
    } else {
      buildOperation<mlir::ONNXMaxPoolOp>(node);
    }
  }

  /*!
   * Special handling for BatchNormalization operations.
   */
  void ImportNodeBatchNormalization(onnx::NodeProto node) {
    int nOuts = node.output().size();
    if (nOuts == 1) {
      // Test mode with one output.
      buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node);
    } else {
      // Training mode with four trailing optional outputs. Not handled yet.
      buildOperation<mlir::ONNXBatchNormalizationOp>(node);
    }
  }

  /*!
   * Special handling for Pad operations.
   */
  void ImportNodePad(onnx::NodeProto node) {
    int nOps = node.input().size();
    if (nOps == 2) {
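      // The optional constant_value operand is absent; synthesize a
      // one-element f32 tensor holding 0.0 so that ONNXPadOp always
      // receives its full operand list.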
      llvm::SmallVector<int64_t, 2> dims;
      dims.push_back(1);
      llvm::SmallVector<float, 2> values;
      values.push_back(0.);
      auto elementType = builder_.getF32Type();
      llvm::ArrayRef<int64_t> tensorDims(dims.data(), dims.size());
      auto tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
      auto constantDenseAttribute =
          mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));

      // Use the special builder defined in ONNXOp.td.inc.
      auto constantOp = builder_.create<mlir::ONNXConstantOp>(
          UnknownLoc(), mlir::Attribute(), constantDenseAttribute);
      mlir::Value constantResult = *(constantOp.getODSResults(0).begin());
      std::vector<mlir::Value> inputs;
      for (const auto &item : node.input())
        if (initializedTensors.ContainKey(legalize_name(item))) {
          inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
              UnknownLoc(), builder_, legalize_name(item)));
        } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
          inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
        }
      inputs.push_back(constantResult);

      int nIn = mlir::ONNXPadOp::getNumberOfOperands();
      int nOut = mlir::ONNXPadOp::getNumberOfResults();
      buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut);
    } else {
      buildOperation<mlir::ONNXPadOp>(node);
    }
  }

  void ImportNode(const onnx::NodeProto &node) {
    llvm::StringRef opName = node.op_type();

    // The following code is generated by gen_doc.py.
    // Refer to Dialect/ONNX/ONNXOps.td for details.
    // When the inputs or outputs of an op do not match the specification,
    // the generic operator is used; one known reason is optional inputs.

#include "src/Builder/OpBuildTable.inc"
#if INCLUDE_ONNX_ML == 1
#include "src/Builder/MLOpBuildTable.inc"
#endif
  }

  /*!
   * Import an output tensor by doing the following:
   * - Add the type of this output tensor to a list of tensor
   *   types representing return types of this graph function.
   * - Add this output tensor to the list of mlir::Values
   *   to be returned by the function representing the computation graph.
   * @param output onnx output tensor ValueInfoProto.
   * @param ret_types a vector of tensor types representing the graph's
   *   output tensor types.
   * @param ret_vals a vector of mlir Values representing the graph's
   *   output tensors.
   */
  void ImportOutputTensor(const onnx::ValueInfoProto &output,
      llvm::SmallVectorImpl<mlir::Type> &ret_types,
      llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
    auto output_tensor_legalized_name = legalize_name(output.name());
    assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
           "Output tensor not found");

    auto tensor_val =
        frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
    ret_types.emplace_back(tensor_val.getType());
    ret_vals.push_back(tensor_val);
  }

  void ImportGraph(
      const onnx::GraphProto &graph, const std::string &name = "main_graph") {
    // Maintain a mapping between the parameter and its initializer.
    for (auto initializer : graph.initializer()) {
      auto initializerName = initializer.name();
      initializedTensors.AddMapping(
          legalize_name(initializerName), initializer);
    }
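
    // Initialized tensors are not imported as graph inputs; they are
    // materialized as constants at their first use (see buildOperation).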

    // Create a function for the graph.
    // TODO:
    // * get name and type for the function.
    // * maintain a list of the defined graphs.
    llvm::SmallVector<mlir::Type, 4> arg_types;

    // Import the input tensor types that are not constant and not initialized.
    for (const auto &input : graph.input())
      if (!initializedTensors.ContainKey(legalize_name(input.name())))
        arg_types.emplace_back(ImportInputTensorType(input));

    // Create the main function.
    auto funcType = builder_.getFunctionType(arg_types, {});
    auto mainFunc =
        mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});

    // Emit the entry point operation which specifies the number of user
    // inputs and outputs.
    auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
        /*numInputs=*/graph.input().size() - graph.initializer().size(),
        /*numOutputs=*/graph.output().size());

    // Get the entry block inside the main function and set the insertion
    // point to it.
    auto &entryBlock = *mainFunc.addEntryBlock();
    builder_.setInsertionPointToStart(&entryBlock);

    module_.push_back(mainFunc);
    module_.push_back(entryPoint);

    // Map graph inputs to entry block arguments.
    // Counter of un-initialized tensors. This counter is used to index the
    // entry block arguments.
    int entryBlockArgIdx = 0;
    for (int i = 0; i < graph.input().size(); ++i) {
      if (!initializedTensors.ContainKey(
              legalize_name(graph.input()[i].name()))) {
        ImportInputTensorSymbol(
            graph.input()[i], entryBlock.getArguments()[entryBlockArgIdx]);
        entryBlockArgIdx++;
      }
    }

    // Create a NoneTyped constant to be used for optional operation inputs
    // which are not used.
    none_ =
        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());

    // Import nodes in the graph.
    for (const auto &item : graph.node()) {
      ImportNode(item);
    }

    llvm::SmallVector<mlir::Type, 4> ret_types;
    llvm::SmallVector<mlir::Value, 4> ret_vals;
    // Import the output tensors.
    for (const auto &output : graph.output()) {
      ImportOutputTensor(output, ret_types, ret_vals);
    }

    // Create a return operation to return all ONNX output tensors.
    builder_.create<mlir::ReturnOp>(UnknownLoc(), ret_vals);
    // Update main function signature to reflect types of newly imported
    // output tensors.
    funcType = builder_.getFunctionType(arg_types, ret_types);
    mainFunc.setType(funcType);
  }
}; // FrontendGenImpl class
} // namespace
} // namespace onnx_mlir

namespace onnx_mlir {

void ImportFrontendModelFile(std::string model_fname,
    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
  onnx::ModelProto model;
  std::fstream input(model_fname, std::ios::in | std::ios::binary);

  auto parse_success = model.ParseFromIstream(&input);
  assert(parse_success && "ONNX model parsing failed.");

  FrontendGenImpl myONNXGen(context);
  module = myONNXGen.ImportONNXModel(model);
}
} // namespace onnx_mlir
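
// For illustration, a hypothetical driver would use the importer as follows
// (a sketch, assuming a caller that owns the context and module):
//
//   mlir::MLIRContext context;
//   mlir::OwningModuleRef module;
//   onnx_mlir::ImportFrontendModelFile("model.onnx", context, module);
//   module->dump();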