Initialize operation arguments with ONNX model constants (#8)
* Save current state.
* Include constant arguments in source.
* Emit constants for Reshape second argument.
* Clean up code.
* Add changes to gen_doc.py file.
* Propagate constant tensor to Reshape second arg to infer shape.
* Update documentation.
* Eliminate constant tensor operations when lowering to KRNL dialect.
* Replace ConstantTensorOp with ConstantOp.
* Add comment to remove temporary Constant lowering code.
* Remove unused shape inference for Constant.
* Remove comment.
* Remove explicit constant elimination.
* Refactor code.
Parent: ba02b90e0b
Commit: fe3279e721
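The mechanism, in miniature: initializers recorded from the ONNX model are preferred over runtime inputs when an operation such as Reshape consumes them, so shape information becomes visible at compile time. A toy, self-contained C++ sketch of that flow (all names here are illustrative; none of this is the commit's actual API):

#include <iostream>
#include <map>
#include <string>
#include <vector>

using Shape = std::vector<long>;

int main() {
  // Step 1: record the model's initializers (here, Reshape's target shape).
  std::map<std::string, Shape> initializers{{"shape", {2, 3, 4}}};

  // Step 2: while importing Reshape, prefer the initializer over treating
  // "shape" as a runtime graph input.
  Shape outputShape(3, -1); // -1 = dynamic dimension
  if (auto it = initializers.find("shape"); it != initializers.end())
    outputShape = it->second; // constant propagated: output is 2x3x4

  for (long d : outputShape)
    std::cout << d << ' '; // prints: 2 3 4
}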
gen_doc.py:

@@ -36,6 +36,7 @@ special_op_handler = dict([
     ("MaxPool", "ImportNodeMaxPool"),
     ("BatchNormalization", "ImportNodeBatchNormalization"),
     ("Pad", "ImportNodePad"),
+    ("Reshape", "ImportNodeReshape"),
     #("Transpose", "ImportNodeTranspose")
 ])
src/builder/CMakeLists.txt:

@@ -1,4 +1,6 @@
 add_library(builder
+        frontend_dialect_helper.cpp
+        frontend_dialect_helper.hpp
         frontend_dialect_transformer.cpp
         frontend_dialect_transformer.hpp
         op_build_table.inc
src/builder/frontend_dialect_helper.cpp (new file):

@@ -0,0 +1,185 @@
+//===------------------- frontend_dialect_helper.cpp ----------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// Helper methods for handling input ONNX models.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/builder/frontend_dialect_helper.hpp"
+
+namespace onnf {
+
+void replaceAll(std::string &str, const std::string &from,
+                const std::string &to) {
+  if (from.empty())
+    return;
+  size_t start_pos = 0;
+  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
+    str.replace(start_pos, from.length(), to);
+    start_pos += to.length(); // In case 'to' contains 'from', like replacing
+                              // 'x' with 'yx'.
+  }
+}
+
+std::string legalize_name(std::string name) {
+  std::replace(name.begin(), name.end(), '/', '_');
+  std::replace(name.begin(), name.end(), '-', '_');
+  replaceAll(name, ":", "_colon_");
+  // If the tensor name starts with a number, prepend n to make it a legal C++
+  // identifier.
+  if (name.size() > 0 && isdigit(name.at(0)))
+    name.insert(0, 1, 'n');
+  return name;
+}
+
+mlir::Value OnnxOnnfSymbolMapping::GetTensorByOnnxName(
+    const std::string &name) {
+  assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
+             onnx_name2onnf_tensor.end() &&
+         "Tensor not found");
+  return onnx_name2onnf_tensor.at(legalize_name(name));
+}
+
+void OnnxOnnfSymbolMapping::AddMapping(
+    const std::string &name, mlir::Value tensor) {
+  assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
+         "Tensor already exists.");
+  onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
+}
+
+bool OnnxOnnfSymbolMapping::ContainKey(std::string name) {
+  return onnx_name2onnf_tensor.count(name) != 0;
+}
+
+template <typename T>
+struct TransformValueToONNXData {
+  static const google::protobuf::RepeatedField<T> data(
+      onnx::TensorProto initializer) {
+    return google::protobuf::RepeatedField<T>();
+  }
+};
+
+template <>
+struct TransformValueToONNXData<double> {
+  static const google::protobuf::RepeatedField<double> data(
+      onnx::TensorProto initializer) {
+    return initializer.double_data();
+  }
+};
+
+template <>
+struct TransformValueToONNXData<float> {
+  static const google::protobuf::RepeatedField<float> data(
+      onnx::TensorProto initializer) {
+    return initializer.float_data();
+  }
+};
+
+template <>
+struct TransformValueToONNXData<int32_t> {
+  static const google::protobuf::RepeatedField<int32_t> data(
+      onnx::TensorProto initializer) {
+    return initializer.int32_data();
+  }
+};
+
+template <>
+struct TransformValueToONNXData<int64_t> {
+  static const google::protobuf::RepeatedField<int64_t> data(
+      onnx::TensorProto initializer) {
+    return initializer.int64_data();
+  }
+};
+
+// Helper method for constructing an array attribute from a model input.
+template <typename T>
+static T *CreateArrayAttribute(onnx::TensorProto initializer, int *size) {
+  if (initializer.raw_data().size()) {
+    // Copy & take care of endianness.
+    std::vector<char> byteInitializer;
+    std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
+              back_inserter(byteInitializer));
+    *size = initializer.raw_data().size() / sizeof(T);
+    return reinterpret_cast<T *>(&byteInitializer[0]);
+  }
+
+  // Copy, no need to take care of endianness.
+  auto data = TransformValueToONNXData<T>::data(initializer);
+  *size = data.size();
+  return &data[0];
+}
+
+void InitializedTensorMapping::AddMapping(
+    std::string name, onnx::TensorProto tensor) {
+  assert(nameToInitializedTensor.count(name) == 0 &&
+         "Tensor initializer already mapped.");
+  nameToInitializedTensor.emplace(name, tensor);
+}
+
+bool InitializedTensorMapping::ContainKey(std::string name) {
+  return nameToInitializedTensor.count(name) != 0;
+}
+
+mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
+    mlir::Location loc, mlir::OpBuilder &builder, std::string name) {
+  // Initializer for input.
+  onnx::TensorProto initializer = GetInitializedTensor(name);
+
+  // Emit ConstantOp and record the mapping between the input and
+  // the constant value.
+  mlir::ArrayAttr constantArrayAttribute;
+  mlir::Type elementType;
+  int length;
+  switch (initializer.data_type()) {
+  case (onnx::TensorProto::FLOAT): {
+    float *typeArray = CreateArrayAttribute<float>(initializer, &length);
+    std::vector<float> arrayAttrInitializer(typeArray, typeArray + length);
+    llvm::ArrayRef<float> array(typeArray, length);
+    constantArrayAttribute = builder.getF32ArrayAttr(array);
+    elementType = builder.getF32Type();
+    break;
+  }
+  case (onnx::TensorProto::INT32): {
+    int32_t *typeArray = CreateArrayAttribute<int32_t>(initializer, &length);
+    std::vector<int32_t> arrayAttrInitializer(typeArray, typeArray + length);
+    llvm::ArrayRef<int32_t> array(typeArray, length);
+    constantArrayAttribute = builder.getI32ArrayAttr(array);
+    elementType = builder.getIntegerType(32);
+    break;
+  }
+  case (onnx::TensorProto::INT64): {
+    int64_t *typeArray = CreateArrayAttribute<int64_t>(initializer, &length);
+    std::vector<int64_t> arrayAttrInitializer(typeArray, typeArray + length);
+    llvm::ArrayRef<int64_t> array(typeArray, length);
+    constantArrayAttribute = builder.getI64ArrayAttr(array);
+    elementType = builder.getIntegerType(64);
+    break;
+  }
+  }
+
+  // Create empty sparse_value attribute.
+  llvm::ArrayRef<int64_t> array;
+  auto sparseValueAttribute = builder.getI64ArrayAttr(array);
+
+  // Create value attribute.
+  llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
+                                     initializer.dims().size());
+  mlir::Type tensorType =
+      mlir::RankedTensorType::get(tensorDims, elementType);
+
+  return builder.create<mlir::ONNXConstantOp>(
+      loc, tensorType, sparseValueAttribute, constantArrayAttribute);
+}
+
+} // namespace onnf
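Note that the raw_data branch of CreateArrayAttribute above hands back a pointer into a function-local buffer, which is fragile to use. A minimal standalone sketch of the same little-endian decoding that returns an owned vector instead (makeRawData and decodeRaw are illustrative names, not part of this commit):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for onnx::TensorProto::raw_data(): little-endian bytes.
std::string makeRawData(const std::vector<int64_t> &values) {
  std::string bytes(values.size() * sizeof(int64_t), '\0');
  std::memcpy(&bytes[0], values.data(), bytes.size());
  return bytes;
}

// Decode raw bytes into an owned vector so no pointer into a dead local
// escapes (assumes a little-endian host, as the commit's code effectively does).
template <typename T>
std::vector<T> decodeRaw(const std::string &raw) {
  std::vector<T> out(raw.size() / sizeof(T));
  std::memcpy(out.data(), raw.data(), out.size() * sizeof(T));
  return out;
}

int main() {
  std::string raw = makeRawData({2, 3, 4}); // e.g. a Reshape target shape
  for (int64_t d : decodeRaw<int64_t>(raw))
    std::cout << d << ' '; // prints: 2 3 4
}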
src/builder/frontend_dialect_helper.hpp (new file):

@@ -0,0 +1,101 @@
+//===------------------- frontend_dialect_helper.hpp ----------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// Helper methods for handling input ONNX models.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <numeric>
+#include <regex>
+#include <tuple>
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "src/dialect/onnx/onnx_ops.hpp"
+#include "onnx/onnx_pb.h"
+
+namespace onnf {
+
+void replaceAll(std::string &str, const std::string &from,
+                const std::string &to);
+
+std::string legalize_name(std::string name);
+
+struct OnnxOnnfSymbolMapping {
+  /*!
+   * Get MLIR tensor by onnx tensor name.
+   * @param name onnx tensor name.
+   * @return onnf tensor corresponding to `name`.
+   */
+  mlir::Value GetTensorByOnnxName(const std::string &name);
+
+  /*!
+   * Add a new mapping from onnx tensor name to MLIR symbol.
+   * @param name onnx tensor name.
+   * @param tensor MLIR Value pointer.
+   */
+  void AddMapping(const std::string &name, mlir::Value tensor);
+
+  bool ContainKey(std::string name);
+
+private:
+  /*!
+   * Mapping from onnx tensor names to MLIR tensor.
+   */
+  std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
+};
+
+struct InitializedTensorMapping {
+  // Add new entry.
+  void AddMapping(std::string name, onnx::TensorProto tensor);
+
+  // Check if an input is initialized. Not all inputs are; some inputs
+  // require data from the user and are not stored inside the ONNX model
+  // itself.
+  bool ContainKey(std::string name);
+
+  // Emit a constant argument (initialized argument) as a ConstantOp.
+  // This allows operations to use the constant data contained in an ONNX
+  // model as they are being compiled, and enables the emission of such
+  // constant operations on demand.
+  //
+  // This allows the propagation of shape information passed in as an
+  // argument to operations such as Reshape and enables other
+  // optimizations such as constant folding.
+  mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
+      mlir::OpBuilder &builder, std::string name);
+
+  // Get initialized tensor.
+  onnx::TensorProto &GetInitializedTensor(std::string name) {
+    assert(nameToInitializedTensor.find(name) !=
+               nameToInitializedTensor.end() &&
+           "Tensor initializer not found");
+    return nameToInitializedTensor.at(name);
+  }
+
+private:
+  // Mapping from ONNX tensor name to InitializedTensor.
+  std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
+};
+
+} // namespace onnf
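A brief sketch of how an importer can use this interface to resolve an operand, preferring a model initializer over a runtime input (resolveOperand is a hypothetical helper for illustration; it is not part of this commit):

#include "src/builder/frontend_dialect_helper.hpp"

namespace onnf {

// Hypothetical helper: if `onnxName` is backed by an initializer, materialize
// it as a constant so later passes can see its value; otherwise fall back to
// the SSA value previously recorded for the graph input.
mlir::Value resolveOperand(InitializedTensorMapping &initializedTensors,
                           OnnxOnnfSymbolMapping &frontendSymbols,
                           mlir::Location loc, mlir::OpBuilder &builder,
                           const std::string &onnxName) {
  if (initializedTensors.ContainKey(legalize_name(onnxName)))
    return initializedTensors.EmitInitializerForInputTensor(
        loc, builder, legalize_name(onnxName));
  return frontendSymbols.GetTensorByOnnxName(onnxName);
}

} // namespace onnf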
src/builder/frontend_dialect_transformer.cpp:

@@ -14,96 +14,20 @@
 //
 //===----------------------------------------------------------------------===//

-#include <map>
-#include <numeric>
-#include <regex>
-#include <string>
-#include <tuple>
-
 // Using backported variant.
 // bstd = backported standard library.
 #include <mpark/variant.hpp>
 namespace bstd = mpark;

-#include "mlir/Analysis/Verifier.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
-
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopedHashTable.h"
-#include "llvm/Support/raw_ostream.h"
-
-#include "src/dialect/onnx/onnx_ops.hpp"
-
 #include "frontend_dialect_transformer.hpp"

 namespace onnf {
 namespace {

-void replaceAll(std::string &str, const std::string &from,
-                const std::string &to) {
-  if (from.empty())
-    return;
-  size_t start_pos = 0;
-  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
-    str.replace(start_pos, from.length(), to);
-    start_pos += to.length(); // In case 'to' contains 'from', like replacing
-                              // 'x' with 'yx'
-  }
-}
-
-std::string legalize_name(std::string name) {
-  std::replace(name.begin(), name.end(), '/', '_');
-  std::replace(name.begin(), name.end(), '-', '_');
-  replaceAll(name, ":", "_colon_");
-  // If tensor name starts with a number, prepend n to make it a legal c++
-  // identifier.
-  if (name.size() > 0 && isdigit(name.at(0)))
-    name.insert(0, 1, 'n');
-  return name;
-}
-
-struct OnnxOnnfSymbolMapping {
-  /*!
-   * Get MLIR tensor by onnx tensor name.
-   * @param name onnx tensor name.
-   * @return onnf tensor corresponding to `name`.
-   */
-  mlir::Value GetTensorByOnnxName(const std::string &name) {
-    assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
-               onnx_name2onnf_tensor.end() &&
-           "Tensor not found");
-    return onnx_name2onnf_tensor.at(legalize_name(name));
-  }
-
-  /*!
-   * Add a new mapping from onnx tensor name to MLIR symbol.
-   * @param name onnx tensor name.
-   * @param tensor MLIR Value pointer.
-   */
-  void AddMapping(const std::string &name, mlir::Value tensor) {
-    assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
-           "Tensor already exists.");
-    onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
-  }
-
-  bool ContainKey(std::string name) {
-    return onnx_name2onnf_tensor.count(name) != 0;
-  }
-
-private:
-  /*!
-   * mapping from onnx tensor names to MLIR tensor.
-   */
-  std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
-};
-
+/*!
+ * The list of tensors initialized by the ONNX model.
+ */
+InitializedTensorMapping initializedTensors;
+
 class FrontendGenImpl {
 public:
@@ -167,8 +91,7 @@ private:
    * @param input onnx input tensor ValueInfoProto.
    * @param arg_types list of mlir types representing types of graph input.
    */
-  void ImportInputTensorType(const onnx::ValueInfoProto &input,
-                             llvm::SmallVector<mlir::Type, 4> &arg_types) {
+  mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) {
     std::vector<int64_t> dims;
     auto shape_proto = input.type().tensor_type().shape();
     auto input_tensor_legalized_name = legalize_name(input.name());
@@ -193,8 +116,7 @@ private:
         (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
     mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
     llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
-    arg_types.emplace_back(
-        mlir::RankedTensorType::get(tensor_dims, elementType));
+    return mlir::RankedTensorType::get(tensor_dims, elementType);
   }

   /*!
@@ -320,16 +242,11 @@ private:
     }
   }

   template <typename T>
-  void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
-                      int expectedNumResults = -1) {
+  void buildOutputAndOperation(const onnx::NodeProto &node,
+      std::vector<mlir::Value> inputs, int expectedNumOperands,
+      int expectedNumResults) {
     bool variadicIn = expectedNumOperands == -1;
     bool variadicOut = expectedNumResults == -1;
-    std::vector<mlir::Value> inputs;
-    for (const auto &item : node.input()) {
-      if (frontend_symbols_.ContainKey(legalize_name(item))) {
-        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
-      }
-    }

     if (!variadicIn)
       for (auto i = inputs.size(); i < expectedNumOperands; i++)
@@ -351,6 +268,37 @@ private:
     }
   }

+  template <typename T>
+  void buildOperation(const onnx::NodeProto &node,
+      int expectedNumOperands = -1,
+      int expectedNumResults = -1) {
+    std::vector<mlir::Value> inputs;
+    for (const auto &item : node.input())
+      if (frontend_symbols_.ContainKey(legalize_name(item)))
+        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
+
+    buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
+                               expectedNumResults);
+  }
+
+  void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
+    std::vector<mlir::Value> inputs;
+    std::string item;
+    for (int i = 0; i < node.input().size(); ++i) {
+      item = node.input()[i];
+      // For the second argument, check if there exists an initializer.
+      if (i == 1 && initializedTensors.ContainKey(legalize_name(item))) {
+        inputs.push_back(
+            initializedTensors.EmitInitializerForInputTensor(
+                UnknownLoc(), builder_, legalize_name(item)));
+      } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
+        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
+      }
+    }
+
+    buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
+  }
+
   /*!
    * Special handle for Conv operations.
    * c++ does not allow template specialization inside a class scope
@@ -452,38 +400,52 @@ private:

   void ImportGraph(const onnx::GraphProto &graph,
                    const std::string &name = "main_graph") {
+    // Maintain a mapping between the parameter and its initializer.
+    for (auto initializer : graph.initializer()) {
+      auto name = initializer.name();
+      initializedTensors.AddMapping(legalize_name(name), initializer);
+    }
+
     // create a function for the graph
     // TODO:
     //  * get name and type for the function.
     //  * maintain a list of the defined graph
     llvm::SmallVector<mlir::Type, 4> arg_types;

-    // Import the input tensor types.
-    for (const auto &input : graph.input()) {
-      ImportInputTensorType(input, arg_types);
-    }
+    // Import the input tensor types that are not constant.
+    for (const auto &input : graph.input())
+      arg_types.emplace_back(ImportInputTensorType(input));

-    // TODO: import the initializer
+    // Create the main function.
     auto funcType = builder_.getFunctionType(arg_types, {});
     auto mainFunc =
         mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
+
+    // Emit the entry point operation which specifies the number of user
+    // inputs and outputs.
     auto entryPoint = mlir::ONNXEntryPointOp::create(
-        UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(),
+        UnknownLoc(), mainFunc,
+        /*numInputs=*/graph.input().size() - graph.initializer().size(),
         /*numOutputs=*/graph.output().size());

+    // Get the entry block inside the main function and set the insertion
+    // point to it.
     auto &entryBlock = *mainFunc.addEntryBlock();
     builder_.setInsertionPointToStart(&entryBlock);

     module_.push_back(mainFunc);
     module_.push_back(entryPoint);

-    for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) {
-      ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
-    }
+    // Map graph inputs to entry block arguments.
+    for (int i = 0; i < graph.input().size(); ++i)
+      ImportInputTensorSymbol(
+          graph.input()[i], entryBlock.getArguments()[i]);

-    // Create a NoneTyped constant.
-    none_ =
-        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
+    // Create a NoneTyped constant to be used for optional operation inputs
+    // which are not used.
+    none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
+        builder_.getUnitAttr());

     // Import nodes in the graph.
     for (const auto &item : graph.node()) {
       ImportNode(item);
|
||||||
|
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
|
|
||||||
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
|
|
||||||
mlir::MLIRContext context;
|
|
||||||
FrontendGenImpl myONNXGen(context);
|
|
||||||
auto module = myONNXGen.ImportONNXModel(model);
|
|
||||||
return module;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ImportFrontendModelFile(std::string model_fname,
|
void ImportFrontendModelFile(std::string model_fname,
|
||||||
mlir::MLIRContext &context,
|
mlir::MLIRContext &context,
|
||||||
mlir::OwningModuleRef &module) {
|
mlir::OwningModuleRef &module) {
|
||||||
|
|
|
src/builder/frontend_dialect_transformer.hpp:

@@ -18,6 +18,8 @@
 #include "onnx/onnx_pb.h"

+#include "src/builder/frontend_dialect_helper.hpp"
+
 namespace mlir {
 class MLIRContext;
 class OwningModuleRef;
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
/*!
|
|
||||||
* Import an ONNX model into ONNF's ONNX Dialect.
|
|
||||||
* @param model onnx model.
|
|
||||||
* @return MLIR::module generated for the ONNX model.
|
|
||||||
*/
|
|
||||||
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Import an ONNX model file into ONNF's ONNX Dialect.
|
* Import an ONNX model file into ONNF's ONNX Dialect.
|
||||||
* @param model_fname file name pointing to the onnx model protobuf.
|
* @param model_fname file name pointing to the onnx model protobuf.
|
||||||
|
|
|
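With ImportFrontendModel removed, file-based import is the remaining entry point. A minimal sketch of a driver calling it (assuming an MLIR build of the same vintage as this commit; "model.onnx" is a placeholder file name):

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"

#include "src/builder/frontend_dialect_transformer.hpp"

int main() {
  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  // Parse an ONNX protobuf file and populate `module` with ONNX-dialect IR.
  onnf::ImportFrontendModelFile("model.onnx", context, module);
  module->dump(); // print the imported MLIR module
  return 0;
}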
src/builder/op_build_table.inc:

@@ -224,7 +224,7 @@ if (opName == "ReduceSumSquare")
 if (opName == "Relu")
   return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Reshape")
-  return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  return ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Resize")
   return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
 if (opName == "ReverseSequence")
@@ -61,7 +61,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
         rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
     SmallVector<Value, 4> DimInfo;
     for (int i = 0; i < memRefShape.size(); ++i) {
-      Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
+      Value index =
+          emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
       // Load index from array of indices.
       Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
       // If a dimension is zero, the actual dimension value is taken from the
@@ -202,5 +202,4 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
     "FloatAttr constant_value, StringAttr mode">];
 }

-
 #endif // ONNX_OPS
@@ -656,9 +656,27 @@ void ONNXReshapeOp::inferShapes() {
   if (outputRank < 0)
     emitError("Shape tensor must have constant shape");

-  SmallVector<int64_t, 2> dims;
-  for (int i = 0; i < outputRank; ++i)
-    dims.emplace_back(-1);
+  // Check if the second argument of ReshapeOp is a constant.
+  // Get the operation that defines the second argument. If this operation is
+  // a `Constant` operation, the target shape of this `Reshape` operation
+  // resides in the `value` attribute of the `Constant` operation.
+  auto *secondArgDefiningOp = (*getODSOperands(1).begin()).getDefiningOp();
+  auto constantOp =
+      dyn_cast_or_null<mlir::ONNXConstantOp>(secondArgDefiningOp);
+
+  SmallVector<int64_t, 2> dims(outputRank, -1);
+  if (constantOp) {
+    ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast<ArrayAttr>();
+
+    if (!valueAttribute)
+      emitError("ArrayAttr expected");
+
+    if (valueAttribute.getValue().size() != outputRank)
+      emitError("Constant value must have same rank as output");
+
+    for (int i = 0; i < outputRank; ++i)
+      dims[i] = valueAttribute.getValue()[i].cast<IntegerAttr>().getInt();
+  }
+
   getResult().setType(
       RankedTensorType::get(dims, inputTensorTy.getElementType()));
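To make the effect of this hunk concrete, a small standalone sketch of the inference rule it introduces (plain C++, no MLIR dependency; inferReshapeDims is an illustrative name, not code from this commit):

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// If Reshape's shape operand is a compile-time constant, every output
// dimension is read from it; otherwise all dimensions stay dynamic (-1).
std::vector<int64_t> inferReshapeDims(
    int64_t outputRank,
    const std::optional<std::vector<int64_t>> &constShape) {
  std::vector<int64_t> dims(outputRank, -1); // -1 marks a dynamic dimension
  if (constShape && static_cast<int64_t>(constShape->size()) == outputRank)
    dims = *constShape;
  return dims;
}

int main() {
  // Shape operand is a runtime tensor: nothing is known statically.
  for (int64_t d : inferReshapeDims(3, std::nullopt))
    std::cout << d << ' '; // prints: -1 -1 -1
  std::cout << '\n';
  // Shape operand comes from an initializer: the output type is fully static.
  for (int64_t d : inferReshapeDims(3, std::vector<int64_t>{2, 3, 4}))
    std::cout << d << ' '; // prints: 2 3 4
  std::cout << '\n';
}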
@@ -91,7 +91,7 @@ struct SplitConvOpPattern : public RewritePattern {
                        1, context) {}

   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
     auto loc = op->getLoc();

     // If convolution does not use padding then no rewrite is required.
@@ -173,7 +173,7 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
 }
-/// on the ONNXReduceSumSquareOp.
+/// on the ONNXConvNoBiasOp.
 void ONNXConvNoBiasOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<SplitConvOpPattern>(context);