Fix a bug in createArrayAttribute (#43)

* Fix a bug in createArrayAttribute

* Use size_t

* Use const auto&
This commit is contained in:
Tung D. Le 2020-03-24 15:04:23 +09:00 committed by GitHub
parent b3719d486b
commit ddff0f1256
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 24 deletions

View File

@ -96,20 +96,22 @@ struct TransformValueToONNXData<int64_t> {
// Helper method for constructing an array attribute from a model input. // Helper method for constructing an array attribute from a model input.
template <typename T> template <typename T>
static T* CreateArrayAttribute(onnx::TensorProto initializer, int *size) { static std::vector<T> CreateArrayAttribute(onnx::TensorProto initializer) {
size_t size;
if (initializer.raw_data().size()) { if (initializer.raw_data().size()) {
// copy & take care of endianness // copy & take care of endianness
std::vector<char> byteInitializer; std::vector<char> byteInitializer;
std::copy(initializer.raw_data().begin(), initializer.raw_data().end(), std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
back_inserter(byteInitializer)); back_inserter(byteInitializer));
*size = initializer.raw_data().size() / sizeof(T); size = initializer.raw_data().size() / sizeof(T);
return reinterpret_cast<T*>(&byteInitializer[0]); T *res = reinterpret_cast<T *>(&byteInitializer[0]);
return std::vector<T>(res, res + size);
} }
// copy, no need to take care of endianness // copy, no need to take care of endianness
auto data = TransformValueToONNXData<T>::data(initializer); auto data = TransformValueToONNXData<T>::data(initializer);
*size = data.size(); size = data.size();
return &data[0]; return std::vector<T>(&data[0], &data[0] + size);
} }
void InitializedTensorMapping::AddMapping( void InitializedTensorMapping::AddMapping(
@ -139,39 +141,32 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
mlir::DenseElementsAttr constantDenseAttribute; mlir::DenseElementsAttr constantDenseAttribute;
mlir::Type elementType; mlir::Type elementType;
mlir::ShapedType tensorType; mlir::ShapedType tensorType;
int length;
switch (initializer.data_type()) { switch (initializer.data_type()) {
case (onnx::TensorProto::FLOAT): { case (onnx::TensorProto::FLOAT): {
float *typeArray = const auto& arrayAttrInitializer =
CreateArrayAttribute<float>(initializer, &length); CreateArrayAttribute<float>(initializer);
std::vector<float> arrayAttrInitializer(
typeArray, typeArray + length);
llvm::ArrayRef<float> array(typeArray, length);
elementType = builder.getF32Type(); elementType = builder.getF32Type();
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(tensorType, array); constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break; break;
} }
case (onnx::TensorProto::INT32): { case (onnx::TensorProto::INT32): {
int32_t *typeArray = const auto& arrayAttrInitializer =
CreateArrayAttribute<int32_t>(initializer, &length); CreateArrayAttribute<int32_t>(initializer);
std::vector<int32_t> arrayAttrInitializer(
typeArray, typeArray + length);
llvm::ArrayRef<int32_t> array(typeArray, length);
elementType = builder.getIntegerType(32); elementType = builder.getIntegerType(32);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(tensorType, array); constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break; break;
} }
case (onnx::TensorProto::INT64): { case (onnx::TensorProto::INT64): {
int64_t *typeArray = const auto& arrayAttrInitializer =
CreateArrayAttribute<int64_t>(initializer, &length); CreateArrayAttribute<int64_t>(initializer);
std::vector<int64_t> arrayAttrInitializer(
typeArray, typeArray + length);
llvm::ArrayRef<int64_t> array(typeArray, length);
elementType = builder.getIntegerType(64); elementType = builder.getIntegerType(64);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(tensorType, array); constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break; break;
} }
} }