Support importing tensor proto. (#116)

* Support importing tensor proto.

* Use signed types, use template.

* Resotre using signless integer types because Reshape faults with integer signs; using traits to declare/define attribute types.

* Simplify attribute importing logic.

* Use existing code to import TensorProto.

* nit.
This commit is contained in:
Tian Jin 2020-05-14 11:05:02 +08:00 committed by GitHub
parent adc08fb93e
commit 4dd3c809c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 97 deletions

View File

@ -126,52 +126,59 @@ bool InitializedTensorMapping::ContainKey(std::string name) {
} }
mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor( mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
mlir::Location loc, mlir::OpBuilder &builder, std::string name) { mlir::Location loc, mlir::OpBuilder &builder, const std::string &name) {
// Initializer for input. // Initializer for input.
onnx::TensorProto initializer = GetInitializedTensor(name); onnx::TensorProto initializer = GetInitializedTensor(name);
// Tensor dimensions.
llvm::ArrayRef<int64_t> tensorDims(
initializer.dims().data(), initializer.dims().size());
// Emit ConstantOp and record the mapping between the input and // Emit ConstantOp and record the mapping between the input and
// the constant value. // the constant value.
// Create value attribute. // Create value attribute.
mlir::DenseElementsAttr constantDenseAttribute; mlir::DenseElementsAttr denseElmAttr =
mlir::Type elementType; onnxTensorProtoToDenseElmAttr(builder, initializer);
mlir::ShapedType tensorType;
// Create ConstantOp for dense array.
return builder.create<mlir::ONNXConstantOp>(
loc, denseElmAttr.getType(), nullptr, denseElmAttr);
}
mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
mlir::OpBuilder &builder, const onnx::TensorProto &initializer) {
// Tensor dimensions.
llvm::ArrayRef<int64_t> tensorDims(
initializer.dims().data(), initializer.dims().size());
mlir::DenseElementsAttr denseElmAttr;
switch (initializer.data_type()) { switch (initializer.data_type()) {
case (onnx::TensorProto::FLOAT): { case (onnx::TensorProto::FLOAT): {
const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer); const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
elementType = builder.getF32Type(); auto elmType = builder.getF32Type();
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
constantDenseAttribute = mlir::DenseElementsAttr::get( denseElmAttr = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer)); tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break; break;
} }
case (onnx::TensorProto::INT32): { case (onnx::TensorProto::INT32): {
const auto &arrayAttrInitializer = const auto &arrayAttrInitializer =
CreateArrayAttribute<int32_t>(initializer); CreateArrayAttribute<int32_t>(initializer);
elementType = builder.getIntegerType(32); auto elmType = builder.getIntegerType(32);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
constantDenseAttribute = mlir::DenseElementsAttr::get( denseElmAttr = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer)); tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break; break;
} }
case (onnx::TensorProto::INT64): { case (onnx::TensorProto::INT64): {
const auto &arrayAttrInitializer = const auto &arrayAttrInitializer =
CreateArrayAttribute<int64_t>(initializer); CreateArrayAttribute<int64_t>(initializer);
elementType = builder.getIntegerType(64); auto elmType = builder.getIntegerType(64);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
constantDenseAttribute = mlir::DenseElementsAttr::get( denseElmAttr = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer)); tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break; break;
} }
default:
llvm_unreachable(
"Failed to import ONNX TensorProto due to unsupported data types.");
} }
return denseElmAttr;
// Create ConstantOp for dense array.
return builder.create<mlir::ONNXConstantOp>(
loc, tensorType, nullptr, constantDenseAttribute);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@ -31,14 +31,12 @@
#include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "onnx/onnx_pb.h"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#if INCLUDE_ONNX_ML == 1 #if INCLUDE_ONNX_ML == 1
#include "src/Dialect/MLONNX/MLONNXOps.hpp" #include "src/Dialect/MLONNX/MLONNXOps.hpp"
#endif #endif
#include "onnx/onnx_pb.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
void replaceAll( void replaceAll(
@ -88,7 +86,7 @@ struct InitializedTensorMapping {
// argument to operations such as Reshape and will enable other // argument to operations such as Reshape and will enable other
// optimizations such as constant folding. // optimizations such as constant folding.
mlir::Value EmitInitializerForInputTensor( mlir::Value EmitInitializerForInputTensor(
mlir::Location loc, mlir::OpBuilder &builder, std::string name); mlir::Location loc, mlir::OpBuilder &builder, const std::string &name);
// Get initialized tensor. // Get initialized tensor.
onnx::TensorProto &GetInitializedTensor(std::string name) { onnx::TensorProto &GetInitializedTensor(std::string name) {
@ -103,4 +101,7 @@ private:
std::map<std::string, onnx::TensorProto> nameToInitializedTensor; std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
}; };
mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
mlir::OpBuilder &builder, const onnx::TensorProto &initializer);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@ -14,6 +14,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <type_traits>
// Using backported variant. // Using backported variant.
// bstd = backported standard library. // bstd = backported standard library.
#include <mpark/variant.hpp> #include <mpark/variant.hpp>
@ -64,18 +65,19 @@ private:
return builder_.getF64Type(); return builder_.getF64Type();
case onnx::TensorProto_DataType::TensorProto_DataType_INT8: case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return builder_.getIntegerType(8); return builder_.getIntegerType(/*width=*/8);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16: case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return builder_.getIntegerType(16); return builder_.getIntegerType(/*width=*/16);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32: case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return builder_.getIntegerType(32); return builder_.getIntegerType(/*width=*/32);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64: case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return builder_.getIntegerType(64); return builder_.getIntegerType(/*width=*/64);
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return builder_.getI1Type(); return builder_.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING: case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
@ -135,81 +137,35 @@ private:
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol); frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
} }
typedef bstd::variant<int64_t, std::vector<int64_t>, float, mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
std::vector<float>, std::string, std::vector<std::string>> onnx::AttributeProto &attr) {
AttrValueType; mlir::Attribute mlirAttr;
struct ONNXAttrVisitor {
ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder)
: _builder(builder), _name(std::move(name)) {}
// Op builder.
mlir::OpBuilder &_builder;
// Name of the attribute being inspected.
std::string _name;
mlir::NamedAttribute operator()(int64_t const &r) {
auto val = _builder.getI64IntegerAttr(r);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<int64_t> const &ints) {
auto val = _builder.getI64ArrayAttr(ints);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(float const &r) {
auto val = _builder.getF32FloatAttr(r);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<float> const &floats) {
auto val = _builder.getF32ArrayAttr(floats);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::string const &s) {
auto val = _builder.getStringAttr(s);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<std::string> const &r) {
assert(false && "type of attribute value is not implemented");
auto val = _builder.getI32IntegerAttr(1);
return _builder.getNamedAttr(_name, val);
};
};
mlir::NamedAttribute convertNameValuePairToNamedAttribute(
std::pair<std::string, AttrValueType> nameAndVal) {
auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_);
return mpark::visit(visitor, nameAndVal.second);
}
static std::pair<std::string, AttrValueType>
convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) {
AttrValueType val;
switch (attr.type()) { switch (attr.type()) {
case onnx::AttributeProto::FLOAT: case onnx::AttributeProto::FLOAT:
return std::make_pair(attr.name(), AttrValueType(attr.f())); mlirAttr = builder_.getF32FloatAttr(attr.f());
break;
case onnx::AttributeProto::INT: case onnx::AttributeProto::INT:
return std::make_pair(attr.name(), AttrValueType(attr.i())); mlirAttr = builder_.getI64IntegerAttr(attr.i());
break;
case onnx::AttributeProto::STRING: case onnx::AttributeProto::STRING:
return std::make_pair(attr.name(), AttrValueType(attr.s())); mlirAttr = builder_.getStringAttr(attr.s());
break;
case onnx::AttributeProto::FLOATS: case onnx::AttributeProto::FLOATS:
val = AttrValueType( mlirAttr = builder_.getF32ArrayAttr(
std::vector<float>(attr.floats().begin(), attr.floats().end())); llvm::makeArrayRef(attr.floats().begin(), attr.floats().end()));
return std::make_pair(attr.name(), val); break;
case onnx::AttributeProto::INTS: case onnx::AttributeProto::INTS:
val = AttrValueType( mlirAttr = builder_.getI64ArrayAttr(
std::vector<int64_t>(attr.ints().begin(), attr.ints().end())); llvm::makeArrayRef(attr.ints().begin(), attr.ints().end()));
return std::make_pair(attr.name(), val); break;
case onnx::AttributeProto::TENSOR:
mlirAttr = onnxTensorProtoToDenseElmAttr(builder_, attr.t());
break;
default: default:
assert(false && "datatype for attribute is not implemented"); llvm_unreachable("datatype for attribute is not implemented");
break; break;
} }
llvm_unreachable("Failed to convert attribute proto to name/value pair"); return builder_.getNamedAttr(attr.name(), mlirAttr);
} }
std::vector<mlir::NamedAttribute> ImportNodeAttributes( std::vector<mlir::NamedAttribute> ImportNodeAttributes(
@ -217,8 +173,7 @@ private:
std::vector<mlir::NamedAttribute> attributes; std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) { for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i); auto attr = node.attribute(i);
auto nameValPair = convertAttributeProtoToNameValuePair(attr); attributes.push_back(convertOnnxAttributeProtoToMlirNamedAttribute(attr));
attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
} }
return attributes; return attributes;
} }