Initialize operation arguments with ONNX model constants (#8)

* Save current state.

* Include constant arguments in source.

* Emit constants for Reshape second argument.

* Clean-up code.

* Add changes to gen_doc.py file.

* Propagate constant tensor to Reshape second arg to infer shape.

* Update documentation.

* Eliminate constant tensor operations when lowering to KRNL dialect.

* Replace ConstantTensorOp with ConstantOp.

* Add comment to remove temporary Constant lowering code.

* Remove unused shape inference for Constant.

* Remove comment.

* Remove explicit constant elimination.

* Refactor code.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-03-10 14:46:35 -04:00 committed by GitHub
parent ba02b90e0b
commit fe3279e721
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 383 additions and 126 deletions

View File

@ -36,6 +36,7 @@ special_op_handler = dict([
("MaxPool", "ImportNodeMaxPool"), ("MaxPool", "ImportNodeMaxPool"),
("BatchNormalization", "ImportNodeBatchNormalization"), ("BatchNormalization", "ImportNodeBatchNormalization"),
("Pad", "ImportNodePad"), ("Pad", "ImportNodePad"),
("Reshape", "ImportNodeReshape"),
#("Transpose", "ImportNodeTranspose") #("Transpose", "ImportNodeTranspose")
]) ])

View File

@ -1,4 +1,6 @@
add_library(builder add_library(builder
frontend_dialect_helper.cpp
frontend_dialect_helper.hpp
frontend_dialect_transformer.cpp frontend_dialect_transformer.cpp
frontend_dialect_transformer.hpp frontend_dialect_transformer.hpp
op_build_table.inc op_build_table.inc

View File

@ -0,0 +1,185 @@
//===------------------- frontend_dialect_helper.cpp ----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Helper methods for handling input ONNX models.
//
//===----------------------------------------------------------------------===//
#include "src/builder/frontend_dialect_helper.hpp"
namespace onnf {
void replaceAll(std::string &str, const std::string &from,
const std::string &to) {
if (from.empty())
return;
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
str.replace(start_pos, from.length(), to);
start_pos += to.length(); // In case 'to' contains 'from', like replacing
// 'x' with 'yx'
}
}
std::string legalize_name(std::string name) {
std::replace(name.begin(), name.end(), '/', '_');
std::replace(name.begin(), name.end(), '-', '_');
replaceAll(name, ":", "_colon_");
// If tensor name starts with a number, prepend n to make it a legal c++
// identifier.
if (name.size() > 0 && isdigit(name.at(0)))
name.insert(0, 1, 'n');
return name;
}
mlir::Value OnnxOnnfSymbolMapping::GetTensorByOnnxName(
const std::string &name) {
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
onnx_name2onnf_tensor.end() &&
"Tensor not found");
return onnx_name2onnf_tensor.at(legalize_name(name));
}
void OnnxOnnfSymbolMapping::AddMapping(
const std::string &name, mlir::Value tensor) {
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
"Tensor already exists.");
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
}
bool OnnxOnnfSymbolMapping::ContainKey(std::string name) {
return onnx_name2onnf_tensor.count(name) != 0;
}
template <typename T>
struct TransformValueToONNXData {
static const google::protobuf::RepeatedField<T> data(
onnx::TensorProto initializer) {
return google::protobuf::RepeatedField<T>();
}
};
template <>
struct TransformValueToONNXData<double> {
static const google::protobuf::RepeatedField<double> data(
onnx::TensorProto initializer) {
return initializer.double_data();
}
};
template <>
struct TransformValueToONNXData<float> {
static const google::protobuf::RepeatedField<float> data(
onnx::TensorProto initializer) {
return initializer.float_data();
}
};
template <>
struct TransformValueToONNXData<int32_t> {
static const google::protobuf::RepeatedField<int32_t> data(
onnx::TensorProto initializer) {
return initializer.int32_data();
}
};
template <>
struct TransformValueToONNXData<int64_t> {
static const google::protobuf::RepeatedField<int64_t> data(
onnx::TensorProto initializer) {
return initializer.int64_data();
}
};
// Helper method for constructing an array attribute from a model input.
template <typename T>
static T* CreateArrayAttribute(onnx::TensorProto initializer, int *size) {
if (initializer.raw_data().size()) {
// copy & take care of endianness
std::vector<char> byteInitializer;
std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
back_inserter(byteInitializer));
*size = initializer.raw_data().size() / sizeof(T);
return reinterpret_cast<T*>(&byteInitializer[0]);
}
// copy, no need to take care of endianness
auto data = TransformValueToONNXData<T>::data(initializer);
*size = data.size();
return &data[0];
}
void InitializedTensorMapping::AddMapping(
std::string name, onnx::TensorProto tensor) {
assert(nameToInitializedTensor.count(name) == 0 &&
"Tensor initializer already mapped.");
nameToInitializedTensor.emplace(name, tensor);
}
bool InitializedTensorMapping::ContainKey(std::string name) {
return nameToInitializedTensor.count(name) != 0;
}
mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
mlir::Location loc, mlir::OpBuilder &builder, std::string name) {
// Initializer for input.
onnx::TensorProto initializer = GetInitializedTensor(name);
// Emit ConstantOp and record the mapping between the input and
// the constant value.
mlir::ArrayAttr constantArrayAttribute;
mlir::Type elementType;
int length;
switch (initializer.data_type()) {
case (onnx::TensorProto::FLOAT): {
float *typeArray =
CreateArrayAttribute<float>(initializer, &length);
std::vector<float> arrayAttrInitializer(
typeArray, typeArray + length);
llvm::ArrayRef<float> array(typeArray, length);
constantArrayAttribute = builder.getF32ArrayAttr(array);
elementType = builder.getF32Type();
break;
}
case (onnx::TensorProto::INT32): {
int32_t *typeArray =
CreateArrayAttribute<int32_t>(initializer, &length);
std::vector<int32_t> arrayAttrInitializer(
typeArray, typeArray + length);
llvm::ArrayRef<int32_t> array(typeArray, length);
constantArrayAttribute = builder.getI32ArrayAttr(array);
elementType = builder.getIntegerType(32);
break;
}
case (onnx::TensorProto::INT64): {
int64_t *typeArray =
CreateArrayAttribute<int64_t>(initializer, &length);
std::vector<int64_t> arrayAttrInitializer(
typeArray, typeArray + length);
llvm::ArrayRef<int64_t> array(typeArray, length);
constantArrayAttribute = builder.getI64ArrayAttr(array);
elementType = builder.getIntegerType(64);
break;
}
}
// Create empty sparse_value attribute.
llvm::ArrayRef<int64_t> array;
auto sparseValueAttribute = builder.getI64ArrayAttr(array);
// Create value attribute.
llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
initializer.dims().size());
mlir::Type tensorType =
mlir::RankedTensorType::get(tensorDims, elementType);
return builder.create<mlir::ONNXConstantOp>(
loc, tensorType, sparseValueAttribute,
constantArrayAttribute);
}
} // namespace onnf

View File

@ -0,0 +1,101 @@
//===------------------- frontend_dialect_helper.hpp ----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Helper methods for handling input ONNX models.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <numeric>
#include <regex>
#include <tuple>
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "onnx/onnx_pb.h"
namespace onnf {
void replaceAll(std::string &str, const std::string &from,
const std::string &to);
std::string legalize_name(std::string name);
struct OnnxOnnfSymbolMapping {
/*!
* Get MLIR tensor by onnx tensor name.
* @param name onnx tensor name.
* @return onnf tensor corresponding to `name`.
*/
mlir::Value GetTensorByOnnxName(const std::string &name);
/*!
* Add a new mapping from onnx tensor name to MLIR symbol.
* @param name onnx tensor name.
* @param tensor MLIR Value pointer.
*/
void AddMapping(const std::string &name, mlir::Value tensor);
bool ContainKey(std::string name);
private:
/*!
* mapping from onnx tensor names to MLIR tensor.
*/
std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
};
struct InitializedTensorMapping {
// Add new entry.
void AddMapping(std::string name, onnx::TensorProto tensor);
// Check if input is initialized. Not all inputs are, some of the inputs
// require input from the user and are not stored inside the ONNX model
// itself.
bool ContainKey(std::string name);
// Emit constant argument (initialized arguments) as a ConstantOp.
// This method will allow operations to use the constant data contained
// in an ONNX model as they are being compiled.
// This method enables the emission of such constant operation on demand.
//
// This will allow the propagation of shape information passed in as an
// argument to operations such as Reshape and will enable other
// optimizations such as constant folding.
mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
mlir::OpBuilder &builder, std::string name);
// Get initialized tensor.
onnx::TensorProto& GetInitializedTensor(std::string name) {
assert(nameToInitializedTensor.find(name) !=
nameToInitializedTensor.end() &&
"Tensor initializer not found");
return nameToInitializedTensor.at(name);
}
private:
// Mapping from ONNX tensor name to InitializedTensor.
std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
};
} // namespace onnf

View File

@ -14,96 +14,20 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <map>
#include <numeric>
#include <regex>
#include <string>
#include <tuple>
// Using backported variant. // Using backported variant.
// bstd = backported standard library. // bstd = backported standard library.
#include <mpark/variant.hpp> #include <mpark/variant.hpp>
namespace bstd = mpark; namespace bstd = mpark;
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "frontend_dialect_transformer.hpp" #include "frontend_dialect_transformer.hpp"
namespace onnf { namespace onnf {
namespace { namespace {
void replaceAll(std::string &str, const std::string &from,
const std::string &to) {
if (from.empty())
return;
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
str.replace(start_pos, from.length(), to);
start_pos += to.length(); // In case 'to' contains 'from', like replacing
// 'x' with 'yx'
}
}
std::string legalize_name(std::string name) {
std::replace(name.begin(), name.end(), '/', '_');
std::replace(name.begin(), name.end(), '-', '_');
replaceAll(name, ":", "_colon_");
// If tensor name starts with a number, prepend n to make it a legal c++
// identifier.
if (name.size() > 0 && isdigit(name.at(0)))
name.insert(0, 1, 'n');
return name;
}
struct OnnxOnnfSymbolMapping {
/*! /*!
* Get MLIR tensor by onnx tensor name. * The list of tensors initialized by the ONNX model.
* @param name onnx tensor name.
* @return onnf tensor corresponding to `name`.
*/ */
mlir::Value GetTensorByOnnxName(const std::string &name) { InitializedTensorMapping initializedTensors;
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
onnx_name2onnf_tensor.end() &&
"Tensor not found");
return onnx_name2onnf_tensor.at(legalize_name(name));
}
/*!
* Add a new mapping from onnx tensor name to MLIR symbol.
* @param name onnx tensor name.
* @param tensor MLIR Value pointer.
*/
void AddMapping(const std::string &name, mlir::Value tensor) {
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
"Tensor already exists.");
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
}
bool ContainKey(std::string name) {
return onnx_name2onnf_tensor.count(name) != 0;
}
private:
/*!
* mapping from onnx tensor names to MLIR tensor.
*/
std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
};
class FrontendGenImpl { class FrontendGenImpl {
public: public:
@ -167,8 +91,7 @@ private:
* @param input onnx input tensor ValueInfoProto. * @param input onnx input tensor ValueInfoProto.
* @param arg_types list of mlir types representing types of graph input. * @param arg_types list of mlir types representing types of graph input.
*/ */
void ImportInputTensorType(const onnx::ValueInfoProto &input, mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) {
llvm::SmallVector<mlir::Type, 4> &arg_types) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
auto shape_proto = input.type().tensor_type().shape(); auto shape_proto = input.type().tensor_type().shape();
auto input_tensor_legalized_name = legalize_name(input.name()); auto input_tensor_legalized_name = legalize_name(input.name());
@ -193,8 +116,7 @@ private:
(onnx::TensorProto_DataType)input.type().tensor_type().elem_type(); (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType); mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size()); llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
arg_types.emplace_back( return mlir::RankedTensorType::get(tensor_dims, elementType);
mlir::RankedTensorType::get(tensor_dims, elementType));
} }
/*! /*!
@ -320,16 +242,11 @@ private:
} }
template <typename T> template <typename T>
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1, void buildOutputAndOperation(const onnx::NodeProto &node,
int expectedNumResults = -1) { std::vector<mlir::Value> inputs, int expectedNumOperands,
int expectedNumResults) {
bool variadicIn = expectedNumOperands == -1; bool variadicIn = expectedNumOperands == -1;
bool variadicOut = expectedNumResults == -1; bool variadicOut = expectedNumResults == -1;
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
if (!variadicIn) if (!variadicIn)
for (auto i = inputs.size(); i < expectedNumOperands; i++) for (auto i = inputs.size(); i < expectedNumOperands; i++)
@ -351,6 +268,37 @@ private:
} }
} }
template <typename T>
void buildOperation(const onnx::NodeProto &node,
int expectedNumOperands = -1,
int expectedNumResults = -1) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input())
if (frontend_symbols_.ContainKey(legalize_name(item)))
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
expectedNumResults);
}
void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
std::vector<mlir::Value> inputs;
std::string item;
for (int i = 0; i < node.input().size(); ++i) {
item = node.input()[i];
// For the second argument, check if there exists an initializer.
if (i == 1 && initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back(
initializedTensors.EmitInitializerForInputTensor(
UnknownLoc(), builder_, legalize_name(item)));
} else if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
}
/*! /*!
* Special handle for Conv operations. * Special handle for Conv operations.
* c++ does not allow template specialization inside a class scope * c++ does not allow template specialization inside a class scope
@ -452,38 +400,52 @@ private:
void ImportGraph(const onnx::GraphProto &graph, void ImportGraph(const onnx::GraphProto &graph,
const std::string &name = "main_graph") { const std::string &name = "main_graph") {
// Maintain a mapping between the parameter and its initializer.
for (auto initializer : graph.initializer()) {
auto name = initializer.name();
initializedTensors.AddMapping(legalize_name(name), initializer);
}
// create a function for the graph // create a function for the graph
// TODO: // TODO:
// * get name and type for the function. // * get name and type for the function.
// * maintain a list of the defined graph // * maintain a list of the defined graph
llvm::SmallVector<mlir::Type, 4> arg_types; llvm::SmallVector<mlir::Type, 4> arg_types;
// Import the input tensor types. // Import the input tensor types that are not constant.
for (const auto &input : graph.input()) { for (const auto &input : graph.input())
ImportInputTensorType(input, arg_types); arg_types.emplace_back(ImportInputTensorType(input));
}
// TODO: import the initializer // Create the main function.
auto funcType = builder_.getFunctionType(arg_types, {}); auto funcType = builder_.getFunctionType(arg_types, {});
auto mainFunc = auto mainFunc =
mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {}); mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
// Emit the entry point operation which specifies the number of user
// inputs and outputs.
auto entryPoint = mlir::ONNXEntryPointOp::create( auto entryPoint = mlir::ONNXEntryPointOp::create(
UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(), UnknownLoc(), mainFunc,
/*numInputs=*/graph.input().size() - graph.initializer().size(),
/*numOutputs=*/graph.output().size()); /*numOutputs=*/graph.output().size());
// Get the entru block inside the main function and set the insertion point
// to it.
auto &entryBlock = *mainFunc.addEntryBlock(); auto &entryBlock = *mainFunc.addEntryBlock();
builder_.setInsertionPointToStart(&entryBlock); builder_.setInsertionPointToStart(&entryBlock);
module_.push_back(mainFunc); module_.push_back(mainFunc);
module_.push_back(entryPoint); module_.push_back(entryPoint);
for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) { // Map graph inputs to entry block arguments.
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it)); for (int i = 0; i < graph.input().size(); ++i)
} ImportInputTensorSymbol(
graph.input()[i], entryBlock.getArguments()[i]);
// Create a NoneTyped constant to be used for optional operation inputs
// which are not used.
none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
builder_.getUnitAttr());
// Create a NoneTyped constant.
none_ =
builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph. // Import nodes in the graph.
for (const auto &item : graph.node()) { for (const auto &item : graph.node()) {
ImportNode(item); ImportNode(item);
@ -509,13 +471,6 @@ private:
namespace onnf { namespace onnf {
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
mlir::MLIRContext context;
FrontendGenImpl myONNXGen(context);
auto module = myONNXGen.ImportONNXModel(model);
return module;
}
void ImportFrontendModelFile(std::string model_fname, void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext &context, mlir::MLIRContext &context,
mlir::OwningModuleRef &module) { mlir::OwningModuleRef &module) {

View File

@ -18,6 +18,8 @@
#include "onnx/onnx_pb.h" #include "onnx/onnx_pb.h"
#include "src/builder/frontend_dialect_helper.hpp"
namespace mlir { namespace mlir {
class MLIRContext; class MLIRContext;
class OwningModuleRef; class OwningModuleRef;
@ -28,13 +30,6 @@ class OwningModuleRef;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace onnf { namespace onnf {
/*!
* Import an ONNX model into ONNF's ONNX Dialect.
* @param model onnx model.
* @return MLIR::module generated for the ONNX model.
*/
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
/*! /*!
* Import an ONNX model file into ONNF's ONNX Dialect. * Import an ONNX model file into ONNF's ONNX Dialect.
* @param model_fname file name pointing to the onnx model protobuf. * @param model_fname file name pointing to the onnx model protobuf.

View File

@ -224,7 +224,7 @@ if (opName == "ReduceSumSquare")
if (opName == "Relu") if (opName == "Relu")
return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Reshape") if (opName == "Reshape")
return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); return ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Resize") if (opName == "Resize")
return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ReverseSequence") if (opName == "ReverseSequence")

View File

@ -61,7 +61,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)); rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
SmallVector<Value, 4> DimInfo; SmallVector<Value, 4> DimInfo;
for (int i = 0; i < memRefShape.size(); ++i) { for (int i = 0; i < memRefShape.size(); ++i) {
Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i); Value index =
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
// Load index from array of indices. // Load index from array of indices.
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
// If a dimension is zero, the actual dimension value is taken from the // If a dimension is zero, the actual dimension value is taken from the

View File

@ -202,5 +202,4 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
"FloatAttr constant_value, StringAttr mode">]; "FloatAttr constant_value, StringAttr mode">];
} }
#endif // ONNX_OPS #endif // ONNX_OPS

View File

@ -656,9 +656,27 @@ void ONNXReshapeOp::inferShapes() {
if (outputRank < 0) if (outputRank < 0)
emitError("Shape tensor must have constant shape"); emitError("Shape tensor must have constant shape");
SmallVector<int64_t, 2> dims; // Check if second argument of ReshapeOp is a constant.
// Get operation that defines the second argument. If this operation is a
// `ConstantTensor` operation, the shape of this `Reshape` operation
// resides in the `value` attribute of the `ConstantTensor` operation.
auto *secondArgDefiningOp = (*getODSOperands(1).begin()).getDefiningOp();
auto constantOp =
dyn_cast_or_null<mlir::ONNXConstantOp>(secondArgDefiningOp);
SmallVector<int64_t, 2> dims(outputRank, -1);
if (constantOp) {
ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast<ArrayAttr>();
if (!valueAttribute)
emitError("ArrayAttr expected");
if (valueAttribute.getValue().size() != outputRank)
emitError("Constant value must have same rank as output");
for (int i=0; i<outputRank; ++i) for (int i=0; i<outputRank; ++i)
dims.emplace_back(-1); dims[i] = valueAttribute.getValue()[i].cast<IntegerAttr>().getInt();
}
getResult().setType( getResult().setType(
RankedTensorType::get(dims, inputTensorTy.getElementType())); RankedTensorType::get(dims, inputTensorTy.getElementType()));

View File

@ -173,7 +173,7 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<MaxPoolSingleOutOpPaddingPattern>(context); results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
} }
/// on the ONNXReduceSumSquareOp. /// on the ONNXConvNoBiasOp.
void ONNXConvNoBiasOp::getCanonicalizationPatterns( void ONNXConvNoBiasOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SplitConvOpPattern>(context); results.insert<SplitConvOpPattern>(context);