Support importing tensor proto. (#116)

* Support importing tensor proto.

* Use signed types, use template.

* Restore using signless integer types because Reshape faults when integer sign is present; use traits to declare/define attribute types.

* Simplify attribute importing logic.

* Use existing code to import TensorProto.

* nit.
This commit is contained in:
Tian Jin 2020-05-14 11:05:02 +08:00 committed by GitHub
parent adc08fb93e
commit 4dd3c809c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 97 deletions

View File

@ -126,52 +126,59 @@ bool InitializedTensorMapping::ContainKey(std::string name) {
}
mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
mlir::Location loc, mlir::OpBuilder &builder, std::string name) {
mlir::Location loc, mlir::OpBuilder &builder, const std::string &name) {
// Initializer for input.
onnx::TensorProto initializer = GetInitializedTensor(name);
// Tensor dimensions.
llvm::ArrayRef<int64_t> tensorDims(
initializer.dims().data(), initializer.dims().size());
// Emit ConstantOp and record the mapping between the input and
// the constant value.
// Create value attribute.
mlir::DenseElementsAttr constantDenseAttribute;
mlir::Type elementType;
mlir::ShapedType tensorType;
mlir::DenseElementsAttr denseElmAttr =
onnxTensorProtoToDenseElmAttr(builder, initializer);
// Create ConstantOp for dense array.
return builder.create<mlir::ONNXConstantOp>(
loc, denseElmAttr.getType(), nullptr, denseElmAttr);
}
mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
mlir::OpBuilder &builder, const onnx::TensorProto &initializer) {
// Tensor dimensions.
llvm::ArrayRef<int64_t> tensorDims(
initializer.dims().data(), initializer.dims().size());
mlir::DenseElementsAttr denseElmAttr;
switch (initializer.data_type()) {
case (onnx::TensorProto::FLOAT): {
const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
elementType = builder.getF32Type();
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
auto elmType = builder.getF32Type();
auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
denseElmAttr = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::INT32): {
const auto &arrayAttrInitializer =
CreateArrayAttribute<int32_t>(initializer);
elementType = builder.getIntegerType(32);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
auto elmType = builder.getIntegerType(32);
auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
denseElmAttr = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::INT64): {
const auto &arrayAttrInitializer =
CreateArrayAttribute<int64_t>(initializer);
elementType = builder.getIntegerType(64);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
auto elmType = builder.getIntegerType(64);
auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
denseElmAttr = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
default:
llvm_unreachable(
"Failed to import ONNX TensorProto due to unsupported data types.");
}
// Create ConstantOp for dense array.
return builder.create<mlir::ONNXConstantOp>(
loc, tensorType, nullptr, constantDenseAttribute);
return denseElmAttr;
}
} // namespace onnx_mlir

View File

@ -31,14 +31,12 @@
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include "onnx/onnx_pb.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#if INCLUDE_ONNX_ML == 1
#include "src/Dialect/MLONNX/MLONNXOps.hpp"
#endif
#include "onnx/onnx_pb.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
void replaceAll(
@ -88,7 +86,7 @@ struct InitializedTensorMapping {
// argument to operations such as Reshape and will enable other
// optimizations such as constant folding.
mlir::Value EmitInitializerForInputTensor(
mlir::Location loc, mlir::OpBuilder &builder, std::string name);
mlir::Location loc, mlir::OpBuilder &builder, const std::string &name);
// Get initialized tensor.
onnx::TensorProto &GetInitializedTensor(std::string name) {
@ -103,4 +101,7 @@ private:
std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
};
mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
mlir::OpBuilder &builder, const onnx::TensorProto &initializer);
} // namespace onnx_mlir

View File

@ -14,6 +14,7 @@
//
//===----------------------------------------------------------------------===//
#include <type_traits>
// Using backported variant.
// bstd = backported standard library.
#include <mpark/variant.hpp>
@ -64,18 +65,19 @@ private:
return builder_.getF64Type();
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return builder_.getIntegerType(8);
return builder_.getIntegerType(/*width=*/8);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return builder_.getIntegerType(16);
return builder_.getIntegerType(/*width=*/16);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return builder_.getIntegerType(32);
return builder_.getIntegerType(/*width=*/32);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return builder_.getIntegerType(64);
return builder_.getIntegerType(/*width=*/64);
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return builder_.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
@ -135,81 +137,35 @@ private:
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
}
typedef bstd::variant<int64_t, std::vector<int64_t>, float,
std::vector<float>, std::string, std::vector<std::string>>
AttrValueType;
struct ONNXAttrVisitor {
ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder)
: _builder(builder), _name(std::move(name)) {}
// Op builder.
mlir::OpBuilder &_builder;
// Name of the attribute being inspected.
std::string _name;
mlir::NamedAttribute operator()(int64_t const &r) {
auto val = _builder.getI64IntegerAttr(r);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<int64_t> const &ints) {
auto val = _builder.getI64ArrayAttr(ints);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(float const &r) {
auto val = _builder.getF32FloatAttr(r);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<float> const &floats) {
auto val = _builder.getF32ArrayAttr(floats);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::string const &s) {
auto val = _builder.getStringAttr(s);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<std::string> const &r) {
assert(false && "type of attribute value is not implemented");
auto val = _builder.getI32IntegerAttr(1);
return _builder.getNamedAttr(_name, val);
};
};
mlir::NamedAttribute convertNameValuePairToNamedAttribute(
std::pair<std::string, AttrValueType> nameAndVal) {
auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_);
return mpark::visit(visitor, nameAndVal.second);
}
static std::pair<std::string, AttrValueType>
convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) {
AttrValueType val;
mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
onnx::AttributeProto &attr) {
mlir::Attribute mlirAttr;
switch (attr.type()) {
case onnx::AttributeProto::FLOAT:
return std::make_pair(attr.name(), AttrValueType(attr.f()));
mlirAttr = builder_.getF32FloatAttr(attr.f());
break;
case onnx::AttributeProto::INT:
return std::make_pair(attr.name(), AttrValueType(attr.i()));
mlirAttr = builder_.getI64IntegerAttr(attr.i());
break;
case onnx::AttributeProto::STRING:
return std::make_pair(attr.name(), AttrValueType(attr.s()));
mlirAttr = builder_.getStringAttr(attr.s());
break;
case onnx::AttributeProto::FLOATS:
val = AttrValueType(
std::vector<float>(attr.floats().begin(), attr.floats().end()));
return std::make_pair(attr.name(), val);
mlirAttr = builder_.getF32ArrayAttr(
llvm::makeArrayRef(attr.floats().begin(), attr.floats().end()));
break;
case onnx::AttributeProto::INTS:
val = AttrValueType(
std::vector<int64_t>(attr.ints().begin(), attr.ints().end()));
return std::make_pair(attr.name(), val);
mlirAttr = builder_.getI64ArrayAttr(
llvm::makeArrayRef(attr.ints().begin(), attr.ints().end()));
break;
case onnx::AttributeProto::TENSOR:
mlirAttr = onnxTensorProtoToDenseElmAttr(builder_, attr.t());
break;
default:
assert(false && "datatype for attribute is not implemented");
llvm_unreachable("datatype for attribute is not implemented");
break;
}
llvm_unreachable("Failed to convert attribute proto to name/value pair");
return builder_.getNamedAttr(attr.name(), mlirAttr);
}
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
@ -217,8 +173,7 @@ private:
std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
auto nameValPair = convertAttributeProtoToNameValuePair(attr);
attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
attributes.push_back(convertOnnxAttributeProtoToMlirNamedAttribute(attr));
}
return attributes;
}