Support importing tensor proto. (#116)
* Support importing tensor proto. * Use signed types, use template. * Resotre using signless integer types because Reshape faults with integer signs; using traits to declare/define attribute types. * Simplify attribute importing logic. * Use existing code to import TensorProto. * nit.
This commit is contained in:
parent
adc08fb93e
commit
4dd3c809c7
|
@ -126,52 +126,59 @@ bool InitializedTensorMapping::ContainKey(std::string name) {
|
|||
}
|
||||
|
||||
mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
|
||||
mlir::Location loc, mlir::OpBuilder &builder, std::string name) {
|
||||
mlir::Location loc, mlir::OpBuilder &builder, const std::string &name) {
|
||||
// Initializer for input.
|
||||
onnx::TensorProto initializer = GetInitializedTensor(name);
|
||||
|
||||
// Tensor dimensions.
|
||||
llvm::ArrayRef<int64_t> tensorDims(
|
||||
initializer.dims().data(), initializer.dims().size());
|
||||
|
||||
// Emit ConstantOp and record the mapping between the input and
|
||||
// the constant value.
|
||||
// Create value attribute.
|
||||
mlir::DenseElementsAttr constantDenseAttribute;
|
||||
mlir::Type elementType;
|
||||
mlir::ShapedType tensorType;
|
||||
mlir::DenseElementsAttr denseElmAttr =
|
||||
onnxTensorProtoToDenseElmAttr(builder, initializer);
|
||||
|
||||
// Create ConstantOp for dense array.
|
||||
return builder.create<mlir::ONNXConstantOp>(
|
||||
loc, denseElmAttr.getType(), nullptr, denseElmAttr);
|
||||
}
|
||||
|
||||
mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
|
||||
mlir::OpBuilder &builder, const onnx::TensorProto &initializer) {
|
||||
// Tensor dimensions.
|
||||
llvm::ArrayRef<int64_t> tensorDims(
|
||||
initializer.dims().data(), initializer.dims().size());
|
||||
mlir::DenseElementsAttr denseElmAttr;
|
||||
switch (initializer.data_type()) {
|
||||
case (onnx::TensorProto::FLOAT): {
|
||||
const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
|
||||
elementType = builder.getF32Type();
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
auto elmType = builder.getF32Type();
|
||||
auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
|
||||
denseElmAttr = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
case (onnx::TensorProto::INT32): {
|
||||
const auto &arrayAttrInitializer =
|
||||
CreateArrayAttribute<int32_t>(initializer);
|
||||
elementType = builder.getIntegerType(32);
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
auto elmType = builder.getIntegerType(32);
|
||||
auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
|
||||
denseElmAttr = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
case (onnx::TensorProto::INT64): {
|
||||
const auto &arrayAttrInitializer =
|
||||
CreateArrayAttribute<int64_t>(initializer);
|
||||
elementType = builder.getIntegerType(64);
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
auto elmType = builder.getIntegerType(64);
|
||||
auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
|
||||
denseElmAttr = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable(
|
||||
"Failed to import ONNX TensorProto due to unsupported data types.");
|
||||
}
|
||||
|
||||
// Create ConstantOp for dense array.
|
||||
return builder.create<mlir::ONNXConstantOp>(
|
||||
loc, tensorType, nullptr, constantDenseAttribute);
|
||||
return denseElmAttr;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
|
|
@ -31,14 +31,12 @@
|
|||
#include "llvm/ADT/ScopedHashTable.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "onnx/onnx_pb.h"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
#if INCLUDE_ONNX_ML == 1
|
||||
#include "src/Dialect/MLONNX/MLONNXOps.hpp"
|
||||
#endif
|
||||
|
||||
#include "onnx/onnx_pb.h"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void replaceAll(
|
||||
|
@ -88,7 +86,7 @@ struct InitializedTensorMapping {
|
|||
// argument to operations such as Reshape and will enable other
|
||||
// optimizations such as constant folding.
|
||||
mlir::Value EmitInitializerForInputTensor(
|
||||
mlir::Location loc, mlir::OpBuilder &builder, std::string name);
|
||||
mlir::Location loc, mlir::OpBuilder &builder, const std::string &name);
|
||||
|
||||
// Get initialized tensor.
|
||||
onnx::TensorProto &GetInitializedTensor(std::string name) {
|
||||
|
@ -103,4 +101,7 @@ private:
|
|||
std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
|
||||
};
|
||||
|
||||
mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
|
||||
mlir::OpBuilder &builder, const onnx::TensorProto &initializer);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <type_traits>
|
||||
// Using backported variant.
|
||||
// bstd = backported standard library.
|
||||
#include <mpark/variant.hpp>
|
||||
|
@ -64,18 +65,19 @@ private:
|
|||
return builder_.getF64Type();
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
||||
return builder_.getIntegerType(8);
|
||||
return builder_.getIntegerType(/*width=*/8);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
||||
return builder_.getIntegerType(16);
|
||||
return builder_.getIntegerType(/*width=*/16);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
||||
return builder_.getIntegerType(32);
|
||||
return builder_.getIntegerType(/*width=*/32);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
||||
return builder_.getIntegerType(64);
|
||||
return builder_.getIntegerType(/*width=*/64);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
||||
return builder_.getI1Type();
|
||||
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||
|
@ -135,81 +137,35 @@ private:
|
|||
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
|
||||
}
|
||||
|
||||
typedef bstd::variant<int64_t, std::vector<int64_t>, float,
|
||||
std::vector<float>, std::string, std::vector<std::string>>
|
||||
AttrValueType;
|
||||
|
||||
struct ONNXAttrVisitor {
|
||||
ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder)
|
||||
: _builder(builder), _name(std::move(name)) {}
|
||||
|
||||
// Op builder.
|
||||
mlir::OpBuilder &_builder;
|
||||
|
||||
// Name of the attribute being inspected.
|
||||
std::string _name;
|
||||
|
||||
mlir::NamedAttribute operator()(int64_t const &r) {
|
||||
auto val = _builder.getI64IntegerAttr(r);
|
||||
return _builder.getNamedAttr(_name, val);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute operator()(std::vector<int64_t> const &ints) {
|
||||
auto val = _builder.getI64ArrayAttr(ints);
|
||||
return _builder.getNamedAttr(_name, val);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute operator()(float const &r) {
|
||||
auto val = _builder.getF32FloatAttr(r);
|
||||
return _builder.getNamedAttr(_name, val);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute operator()(std::vector<float> const &floats) {
|
||||
auto val = _builder.getF32ArrayAttr(floats);
|
||||
return _builder.getNamedAttr(_name, val);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute operator()(std::string const &s) {
|
||||
auto val = _builder.getStringAttr(s);
|
||||
return _builder.getNamedAttr(_name, val);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute operator()(std::vector<std::string> const &r) {
|
||||
assert(false && "type of attribute value is not implemented");
|
||||
auto val = _builder.getI32IntegerAttr(1);
|
||||
return _builder.getNamedAttr(_name, val);
|
||||
};
|
||||
};
|
||||
|
||||
mlir::NamedAttribute convertNameValuePairToNamedAttribute(
|
||||
std::pair<std::string, AttrValueType> nameAndVal) {
|
||||
auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_);
|
||||
return mpark::visit(visitor, nameAndVal.second);
|
||||
}
|
||||
|
||||
static std::pair<std::string, AttrValueType>
|
||||
convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) {
|
||||
AttrValueType val;
|
||||
mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
|
||||
onnx::AttributeProto &attr) {
|
||||
mlir::Attribute mlirAttr;
|
||||
switch (attr.type()) {
|
||||
case onnx::AttributeProto::FLOAT:
|
||||
return std::make_pair(attr.name(), AttrValueType(attr.f()));
|
||||
mlirAttr = builder_.getF32FloatAttr(attr.f());
|
||||
break;
|
||||
case onnx::AttributeProto::INT:
|
||||
return std::make_pair(attr.name(), AttrValueType(attr.i()));
|
||||
mlirAttr = builder_.getI64IntegerAttr(attr.i());
|
||||
break;
|
||||
case onnx::AttributeProto::STRING:
|
||||
return std::make_pair(attr.name(), AttrValueType(attr.s()));
|
||||
mlirAttr = builder_.getStringAttr(attr.s());
|
||||
break;
|
||||
case onnx::AttributeProto::FLOATS:
|
||||
val = AttrValueType(
|
||||
std::vector<float>(attr.floats().begin(), attr.floats().end()));
|
||||
return std::make_pair(attr.name(), val);
|
||||
mlirAttr = builder_.getF32ArrayAttr(
|
||||
llvm::makeArrayRef(attr.floats().begin(), attr.floats().end()));
|
||||
break;
|
||||
case onnx::AttributeProto::INTS:
|
||||
val = AttrValueType(
|
||||
std::vector<int64_t>(attr.ints().begin(), attr.ints().end()));
|
||||
return std::make_pair(attr.name(), val);
|
||||
mlirAttr = builder_.getI64ArrayAttr(
|
||||
llvm::makeArrayRef(attr.ints().begin(), attr.ints().end()));
|
||||
break;
|
||||
case onnx::AttributeProto::TENSOR:
|
||||
mlirAttr = onnxTensorProtoToDenseElmAttr(builder_, attr.t());
|
||||
break;
|
||||
default:
|
||||
assert(false && "datatype for attribute is not implemented");
|
||||
llvm_unreachable("datatype for attribute is not implemented");
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Failed to convert attribute proto to name/value pair");
|
||||
return builder_.getNamedAttr(attr.name(), mlirAttr);
|
||||
}
|
||||
|
||||
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
|
||||
|
@ -217,8 +173,7 @@ private:
|
|||
std::vector<mlir::NamedAttribute> attributes;
|
||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||
auto attr = node.attribute(i);
|
||||
auto nameValPair = convertAttributeProtoToNameValuePair(attr);
|
||||
attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
|
||||
attributes.push_back(convertOnnxAttributeProtoToMlirNamedAttribute(attr));
|
||||
}
|
||||
return attributes;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue