Support importing tensor proto. (#116)
* Support importing tensor proto.
* Use signed types, use template.
* Restore using signless integer types because Reshape faults with integer signs; use traits to declare/define attribute types.
* Simplify attribute importing logic.
* Use existing code to import TensorProto.
* nit.
parent adc08fb93e
commit 4dd3c809c7
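A note on the "signless integer types" point in the message above: in MLIR, builder.getIntegerType(32) yields the signless builtin type i32, so both TensorProto INT32 and UINT32 element types land on the same MLIR type (see the element-type mapping later in this diff), which, per the commit message, is what the Reshape import path requires. A minimal sketch of that mapping, assuming only core MLIR headers (the importer class in the diff wraps the same call in its own switch):

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

mlir::Type onnxInt32ElementType(mlir::MLIRContext *ctx) {
  mlir::OpBuilder builder(ctx);
  // getIntegerType(32) produces the signless builtin type `i32`; both
  // TensorProto INT32 and UINT32 are mapped onto it, avoiding the
  // signed/unsigned variants that the commit message says Reshape faults on.
  return builder.getIntegerType(/*width=*/32);
}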
@@ -126,52 +126,59 @@ bool InitializedTensorMapping::ContainKey(std::string name) {
 }
 
 mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
-    mlir::Location loc, mlir::OpBuilder &builder, std::string name) {
+    mlir::Location loc, mlir::OpBuilder &builder, const std::string &name) {
   // Initializer for input.
   onnx::TensorProto initializer = GetInitializedTensor(name);
 
-  // Tensor dimensions.
-  llvm::ArrayRef<int64_t> tensorDims(
-      initializer.dims().data(), initializer.dims().size());
-
   // Emit ConstantOp and record the mapping between the input and
   // the constant value.
   // Create value attribute.
-  mlir::DenseElementsAttr constantDenseAttribute;
-  mlir::Type elementType;
-  mlir::ShapedType tensorType;
+  mlir::DenseElementsAttr denseElmAttr =
+      onnxTensorProtoToDenseElmAttr(builder, initializer);
+
+  // Create ConstantOp for dense array.
+  return builder.create<mlir::ONNXConstantOp>(
+      loc, denseElmAttr.getType(), nullptr, denseElmAttr);
+}
+
+mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
+    mlir::OpBuilder &builder, const onnx::TensorProto &initializer) {
+  // Tensor dimensions.
+  llvm::ArrayRef<int64_t> tensorDims(
+      initializer.dims().data(), initializer.dims().size());
+  mlir::DenseElementsAttr denseElmAttr;
   switch (initializer.data_type()) {
   case (onnx::TensorProto::FLOAT): {
    const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
-    elementType = builder.getF32Type();
-    tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
-    constantDenseAttribute = mlir::DenseElementsAttr::get(
+    auto elmType = builder.getF32Type();
+    auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
+    denseElmAttr = mlir::DenseElementsAttr::get(
        tensorType, llvm::makeArrayRef(arrayAttrInitializer));
    break;
  }
  case (onnx::TensorProto::INT32): {
    const auto &arrayAttrInitializer =
        CreateArrayAttribute<int32_t>(initializer);
-    elementType = builder.getIntegerType(32);
-    tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
-    constantDenseAttribute = mlir::DenseElementsAttr::get(
+    auto elmType = builder.getIntegerType(32);
+    auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
+    denseElmAttr = mlir::DenseElementsAttr::get(
        tensorType, llvm::makeArrayRef(arrayAttrInitializer));
    break;
  }
  case (onnx::TensorProto::INT64): {
    const auto &arrayAttrInitializer =
        CreateArrayAttribute<int64_t>(initializer);
-    elementType = builder.getIntegerType(64);
-    tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
-    constantDenseAttribute = mlir::DenseElementsAttr::get(
+    auto elmType = builder.getIntegerType(64);
+    auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
+    denseElmAttr = mlir::DenseElementsAttr::get(
        tensorType, llvm::makeArrayRef(arrayAttrInitializer));
    break;
  }
+  default:
+    llvm_unreachable(
+        "Failed to import ONNX TensorProto due to unsupported data types.");
  }
-
-  // Create ConstantOp for dense array.
-  return builder.create<mlir::ONNXConstantOp>(
-      loc, tensorType, nullptr, constantDenseAttribute);
+  return denseElmAttr;
 }
 
 } // namespace onnx_mlir
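The hunk above relies on CreateArrayAttribute<T>(initializer), a templated helper defined elsewhere in the file and not shown in this diff. As a rough sketch of what such a helper typically does for an ONNX initializer (an assumption about its shape, not the repository's exact implementation), it flattens the tensor payload into a contiguous std::vector<T> that DenseElementsAttr::get can consume:

#include <cstring>
#include <vector>
#include "onnx/onnx_pb.h"

// Hedged sketch only; the real CreateArrayAttribute<T> may handle more cases.
// Shown for the raw_data path, the most common serialization of initializers.
template <typename T>
std::vector<T> tensorProtoToVector(const onnx::TensorProto &tp) {
  std::vector<T> result;
  if (!tp.raw_data().empty()) {
    // raw_data packs the elements contiguously in little-endian order.
    result.resize(tp.raw_data().size() / sizeof(T));
    std::memcpy(result.data(), tp.raw_data().data(), tp.raw_data().size());
  }
  // When raw_data is absent, the typed repeated fields (float_data,
  // int32_data, int64_data, ...) would be copied instead.
  return result;
}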
@@ -31,14 +31,12 @@
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "onnx/onnx_pb.h"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #if INCLUDE_ONNX_ML == 1
 #include "src/Dialect/MLONNX/MLONNXOps.hpp"
 #endif
 
-#include "onnx/onnx_pb.h"
-#include "src/Dialect/ONNX/ONNXOps.hpp"
-
 namespace onnx_mlir {
 
 void replaceAll(
@@ -88,7 +86,7 @@ struct InitializedTensorMapping {
  // argument to operations such as Reshape and will enable other
  // optimizations such as constant folding.
  mlir::Value EmitInitializerForInputTensor(
-      mlir::Location loc, mlir::OpBuilder &builder, std::string name);
+      mlir::Location loc, mlir::OpBuilder &builder, const std::string &name);
 
  // Get initialized tensor.
  onnx::TensorProto &GetInitializedTensor(std::string name) {
@@ -103,4 +101,7 @@ private:
  std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
 };
 
+mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
+    mlir::OpBuilder &builder, const onnx::TensorProto &initializer);
+
 } // namespace onnx_mlir
@@ -14,6 +14,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <type_traits>
 // Using backported variant.
 // bstd = backported standard library.
 #include <mpark/variant.hpp>
@@ -64,18 +65,19 @@ private:
    return builder_.getF64Type();
  case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
  case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
-    return builder_.getIntegerType(8);
+    return builder_.getIntegerType(/*width=*/8);
  case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
  case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
-    return builder_.getIntegerType(16);
+    return builder_.getIntegerType(/*width=*/16);
  case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
  case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
-    return builder_.getIntegerType(32);
+    return builder_.getIntegerType(/*width=*/32);
  case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
  case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
-    return builder_.getIntegerType(64);
+    return builder_.getIntegerType(/*width=*/64);
  case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
    return builder_.getI1Type();
+
  case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
  case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
  case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
@@ -135,81 +137,35 @@ private:
    frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
  }
 
-  typedef bstd::variant<int64_t, std::vector<int64_t>, float,
-      std::vector<float>, std::string, std::vector<std::string>>
-      AttrValueType;
-
-  struct ONNXAttrVisitor {
-    ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder)
-        : _builder(builder), _name(std::move(name)) {}
-
-    // Op builder.
-    mlir::OpBuilder &_builder;
-
-    // Name of the attribute being inspected.
-    std::string _name;
-
-    mlir::NamedAttribute operator()(int64_t const &r) {
-      auto val = _builder.getI64IntegerAttr(r);
-      return _builder.getNamedAttr(_name, val);
-    }
-
-    mlir::NamedAttribute operator()(std::vector<int64_t> const &ints) {
-      auto val = _builder.getI64ArrayAttr(ints);
-      return _builder.getNamedAttr(_name, val);
-    }
-
-    mlir::NamedAttribute operator()(float const &r) {
-      auto val = _builder.getF32FloatAttr(r);
-      return _builder.getNamedAttr(_name, val);
-    }
-
-    mlir::NamedAttribute operator()(std::vector<float> const &floats) {
-      auto val = _builder.getF32ArrayAttr(floats);
-      return _builder.getNamedAttr(_name, val);
-    }
-
-    mlir::NamedAttribute operator()(std::string const &s) {
-      auto val = _builder.getStringAttr(s);
-      return _builder.getNamedAttr(_name, val);
-    }
-
-    mlir::NamedAttribute operator()(std::vector<std::string> const &r) {
-      assert(false && "type of attribute value is not implemented");
-      auto val = _builder.getI32IntegerAttr(1);
-      return _builder.getNamedAttr(_name, val);
-    };
-  };
-
-  mlir::NamedAttribute convertNameValuePairToNamedAttribute(
-      std::pair<std::string, AttrValueType> nameAndVal) {
-    auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_);
-    return mpark::visit(visitor, nameAndVal.second);
-  }
-
-  static std::pair<std::string, AttrValueType>
-  convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) {
-    AttrValueType val;
+  mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
+      onnx::AttributeProto &attr) {
+    mlir::Attribute mlirAttr;
    switch (attr.type()) {
    case onnx::AttributeProto::FLOAT:
-      return std::make_pair(attr.name(), AttrValueType(attr.f()));
+      mlirAttr = builder_.getF32FloatAttr(attr.f());
+      break;
    case onnx::AttributeProto::INT:
-      return std::make_pair(attr.name(), AttrValueType(attr.i()));
+      mlirAttr = builder_.getI64IntegerAttr(attr.i());
+      break;
    case onnx::AttributeProto::STRING:
-      return std::make_pair(attr.name(), AttrValueType(attr.s()));
+      mlirAttr = builder_.getStringAttr(attr.s());
+      break;
    case onnx::AttributeProto::FLOATS:
-      val = AttrValueType(
-          std::vector<float>(attr.floats().begin(), attr.floats().end()));
-      return std::make_pair(attr.name(), val);
+      mlirAttr = builder_.getF32ArrayAttr(
+          llvm::makeArrayRef(attr.floats().begin(), attr.floats().end()));
+      break;
    case onnx::AttributeProto::INTS:
-      val = AttrValueType(
-          std::vector<int64_t>(attr.ints().begin(), attr.ints().end()));
-      return std::make_pair(attr.name(), val);
+      mlirAttr = builder_.getI64ArrayAttr(
+          llvm::makeArrayRef(attr.ints().begin(), attr.ints().end()));
+      break;
+    case onnx::AttributeProto::TENSOR:
+      mlirAttr = onnxTensorProtoToDenseElmAttr(builder_, attr.t());
+      break;
    default:
-      assert(false && "datatype for attribute is not implemented");
+      llvm_unreachable("datatype for attribute is not implemented");
      break;
    }
-    llvm_unreachable("Failed to convert attribute proto to name/value pair");
+    return builder_.getNamedAttr(attr.name(), mlirAttr);
  }
 
  std::vector<mlir::NamedAttribute> ImportNodeAttributes(
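With the new TENSOR case above, tensor-valued attributes are imported through the same onnxTensorProtoToDenseElmAttr helper as graph initializers. For orientation, a hedged sketch of what such an attribute looks like on the protobuf side (plain ONNX protobuf API, not code from this patch); attr.t() in the TENSOR case reads exactly this embedded TensorProto, e.g. the value attribute of an ONNX Constant node:

#include "onnx/onnx_pb.h"

// Illustrative only: build a tensor-valued AttributeProto holding a 1-D
// float tensor of two elements, as an ONNX Constant node would carry it.
onnx::AttributeProto makeTensorAttribute() {
  onnx::AttributeProto attr;
  attr.set_name("value");
  attr.set_type(onnx::AttributeProto::TENSOR);
  onnx::TensorProto *t = attr.mutable_t();
  t->set_data_type(onnx::TensorProto::FLOAT);
  t->add_dims(2);
  t->add_float_data(1.0f);
  t->add_float_data(2.0f);
  return attr;
}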
@@ -217,8 +173,7 @@ private:
    std::vector<mlir::NamedAttribute> attributes;
    for (int i = 0; i < node.attribute_size(); ++i) {
      auto attr = node.attribute(i);
-      auto nameValPair = convertAttributeProtoToNameValuePair(attr);
-      attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
+      attributes.push_back(convertOnnxAttributeProtoToMlirNamedAttribute(attr));
    }
    return attributes;
  }