From 4dd3c809c7bbdb3ef1e6bc86b267600d9e0c3867 Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Thu, 14 May 2020 11:05:02 +0800 Subject: [PATCH] Support importing tensor proto. (#116) * Support importing tensor proto. * Use signed types, use template. * Resotre using signless integer types because Reshape faults with integer signs; using traits to declare/define attribute types. * Simplify attribute importing logic. * Use existing code to import TensorProto. * nit. --- src/Builder/FrontendDialectHelper.cpp | 49 ++++++----- src/Builder/FrontendDialectHelper.hpp | 9 +- src/Builder/FrontendDialectTransformer.cpp | 99 ++++++---------------- 3 files changed, 60 insertions(+), 97 deletions(-) diff --git a/src/Builder/FrontendDialectHelper.cpp b/src/Builder/FrontendDialectHelper.cpp index 49bce5a..63bbde6 100644 --- a/src/Builder/FrontendDialectHelper.cpp +++ b/src/Builder/FrontendDialectHelper.cpp @@ -126,52 +126,59 @@ bool InitializedTensorMapping::ContainKey(std::string name) { } mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor( - mlir::Location loc, mlir::OpBuilder &builder, std::string name) { + mlir::Location loc, mlir::OpBuilder &builder, const std::string &name) { // Initializer for input. onnx::TensorProto initializer = GetInitializedTensor(name); - // Tensor dimensions. - llvm::ArrayRef tensorDims( - initializer.dims().data(), initializer.dims().size()); - // Emit ConstantOp and record the mapping between the input and // the constant value. // Create value attribute. - mlir::DenseElementsAttr constantDenseAttribute; - mlir::Type elementType; - mlir::ShapedType tensorType; + mlir::DenseElementsAttr denseElmAttr = + onnxTensorProtoToDenseElmAttr(builder, initializer); + + // Create ConstantOp for dense array. + return builder.create( + loc, denseElmAttr.getType(), nullptr, denseElmAttr); +} + +mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr( + mlir::OpBuilder &builder, const onnx::TensorProto &initializer) { + // Tensor dimensions. + llvm::ArrayRef tensorDims( + initializer.dims().data(), initializer.dims().size()); + mlir::DenseElementsAttr denseElmAttr; switch (initializer.data_type()) { case (onnx::TensorProto::FLOAT): { const auto &arrayAttrInitializer = CreateArrayAttribute(initializer); - elementType = builder.getF32Type(); - tensorType = mlir::RankedTensorType::get(tensorDims, elementType); - constantDenseAttribute = mlir::DenseElementsAttr::get( + auto elmType = builder.getF32Type(); + auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType); + denseElmAttr = mlir::DenseElementsAttr::get( tensorType, llvm::makeArrayRef(arrayAttrInitializer)); break; } case (onnx::TensorProto::INT32): { const auto &arrayAttrInitializer = CreateArrayAttribute(initializer); - elementType = builder.getIntegerType(32); - tensorType = mlir::RankedTensorType::get(tensorDims, elementType); - constantDenseAttribute = mlir::DenseElementsAttr::get( + auto elmType = builder.getIntegerType(32); + auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType); + denseElmAttr = mlir::DenseElementsAttr::get( tensorType, llvm::makeArrayRef(arrayAttrInitializer)); break; } case (onnx::TensorProto::INT64): { const auto &arrayAttrInitializer = CreateArrayAttribute(initializer); - elementType = builder.getIntegerType(64); - tensorType = mlir::RankedTensorType::get(tensorDims, elementType); - constantDenseAttribute = mlir::DenseElementsAttr::get( + auto elmType = builder.getIntegerType(64); + auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType); + denseElmAttr = mlir::DenseElementsAttr::get( tensorType, llvm::makeArrayRef(arrayAttrInitializer)); break; } + default: + llvm_unreachable( + "Failed to import ONNX TensorProto due to unsupported data types."); } - - // Create ConstantOp for dense array. - return builder.create( - loc, tensorType, nullptr, constantDenseAttribute); + return denseElmAttr; } } // namespace onnx_mlir diff --git a/src/Builder/FrontendDialectHelper.hpp b/src/Builder/FrontendDialectHelper.hpp index 2f329da..584b2f3 100644 --- a/src/Builder/FrontendDialectHelper.hpp +++ b/src/Builder/FrontendDialectHelper.hpp @@ -31,14 +31,12 @@ #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/raw_ostream.h" +#include "onnx/onnx_pb.h" #include "src/Dialect/ONNX/ONNXOps.hpp" #if INCLUDE_ONNX_ML == 1 #include "src/Dialect/MLONNX/MLONNXOps.hpp" #endif -#include "onnx/onnx_pb.h" -#include "src/Dialect/ONNX/ONNXOps.hpp" - namespace onnx_mlir { void replaceAll( @@ -88,7 +86,7 @@ struct InitializedTensorMapping { // argument to operations such as Reshape and will enable other // optimizations such as constant folding. mlir::Value EmitInitializerForInputTensor( - mlir::Location loc, mlir::OpBuilder &builder, std::string name); + mlir::Location loc, mlir::OpBuilder &builder, const std::string &name); // Get initialized tensor. onnx::TensorProto &GetInitializedTensor(std::string name) { @@ -103,4 +101,7 @@ private: std::map nameToInitializedTensor; }; +mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr( + mlir::OpBuilder &builder, const onnx::TensorProto &initializer); + } // namespace onnx_mlir diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index c7af1ea..068b222 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -14,6 +14,7 @@ // //===----------------------------------------------------------------------===// +#include // Using backported variant. // bstd = backported standard library. #include @@ -64,18 +65,19 @@ private: return builder_.getF64Type(); case onnx::TensorProto_DataType::TensorProto_DataType_INT8: case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: - return builder_.getIntegerType(8); + return builder_.getIntegerType(/*width=*/8); case onnx::TensorProto_DataType::TensorProto_DataType_INT16: case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: - return builder_.getIntegerType(16); + return builder_.getIntegerType(/*width=*/16); case onnx::TensorProto_DataType::TensorProto_DataType_INT32: case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: - return builder_.getIntegerType(32); + return builder_.getIntegerType(/*width=*/32); case onnx::TensorProto_DataType::TensorProto_DataType_INT64: case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: - return builder_.getIntegerType(64); + return builder_.getIntegerType(/*width=*/64); case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return builder_.getI1Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_STRING: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: @@ -135,81 +137,35 @@ private: frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol); } - typedef bstd::variant, float, - std::vector, std::string, std::vector> - AttrValueType; - - struct ONNXAttrVisitor { - ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder) - : _builder(builder), _name(std::move(name)) {} - - // Op builder. - mlir::OpBuilder &_builder; - - // Name of the attribute being inspected. - std::string _name; - - mlir::NamedAttribute operator()(int64_t const &r) { - auto val = _builder.getI64IntegerAttr(r); - return _builder.getNamedAttr(_name, val); - } - - mlir::NamedAttribute operator()(std::vector const &ints) { - auto val = _builder.getI64ArrayAttr(ints); - return _builder.getNamedAttr(_name, val); - } - - mlir::NamedAttribute operator()(float const &r) { - auto val = _builder.getF32FloatAttr(r); - return _builder.getNamedAttr(_name, val); - } - - mlir::NamedAttribute operator()(std::vector const &floats) { - auto val = _builder.getF32ArrayAttr(floats); - return _builder.getNamedAttr(_name, val); - } - - mlir::NamedAttribute operator()(std::string const &s) { - auto val = _builder.getStringAttr(s); - return _builder.getNamedAttr(_name, val); - } - - mlir::NamedAttribute operator()(std::vector const &r) { - assert(false && "type of attribute value is not implemented"); - auto val = _builder.getI32IntegerAttr(1); - return _builder.getNamedAttr(_name, val); - }; - }; - - mlir::NamedAttribute convertNameValuePairToNamedAttribute( - std::pair nameAndVal) { - auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_); - return mpark::visit(visitor, nameAndVal.second); - } - - static std::pair - convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) { - AttrValueType val; + mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute( + onnx::AttributeProto &attr) { + mlir::Attribute mlirAttr; switch (attr.type()) { case onnx::AttributeProto::FLOAT: - return std::make_pair(attr.name(), AttrValueType(attr.f())); + mlirAttr = builder_.getF32FloatAttr(attr.f()); + break; case onnx::AttributeProto::INT: - return std::make_pair(attr.name(), AttrValueType(attr.i())); + mlirAttr = builder_.getI64IntegerAttr(attr.i()); + break; case onnx::AttributeProto::STRING: - return std::make_pair(attr.name(), AttrValueType(attr.s())); + mlirAttr = builder_.getStringAttr(attr.s()); + break; case onnx::AttributeProto::FLOATS: - val = AttrValueType( - std::vector(attr.floats().begin(), attr.floats().end())); - return std::make_pair(attr.name(), val); + mlirAttr = builder_.getF32ArrayAttr( + llvm::makeArrayRef(attr.floats().begin(), attr.floats().end())); + break; case onnx::AttributeProto::INTS: - val = AttrValueType( - std::vector(attr.ints().begin(), attr.ints().end())); - return std::make_pair(attr.name(), val); + mlirAttr = builder_.getI64ArrayAttr( + llvm::makeArrayRef(attr.ints().begin(), attr.ints().end())); + break; + case onnx::AttributeProto::TENSOR: + mlirAttr = onnxTensorProtoToDenseElmAttr(builder_, attr.t()); + break; default: - assert(false && "datatype for attribute is not implemented"); + llvm_unreachable("datatype for attribute is not implemented"); break; } - llvm_unreachable("Failed to convert attribute proto to name/value pair"); + return builder_.getNamedAttr(attr.name(), mlirAttr); } std::vector ImportNodeAttributes( @@ -217,8 +173,7 @@ private: std::vector attributes; for (int i = 0; i < node.attribute_size(); ++i) { auto attr = node.attribute(i); - auto nameValPair = convertAttributeProtoToNameValuePair(attr); - attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair)); + attributes.push_back(convertOnnxAttributeProtoToMlirNamedAttribute(attr)); } return attributes; }