Initialize operation arguments with ONNX model constants (#8)
* Save current state. * Include constant arguments in source. * Emit constants for Reshape second argument. * Clean-up code. * Add changes to gen_doc.py file. * Propagate constant tensor to Reshape second arg to infer shape. * Update documentation. * Eliminate constant tensor operations when lowering to KRNL dialect. * Replace ConstantTensorOp with ConstantOp. * Add comment to remove temporary Constant lowering code. * Remove unused shape inference for Constant. * Remove comment. * Remove explicit constant elimination. * Refactor code.
This commit is contained in:
		
							parent
							
								
									ba02b90e0b
								
							
						
					
					
						commit
						fe3279e721
					
				| 
						 | 
				
			
			@ -36,6 +36,7 @@ special_op_handler = dict([
 | 
			
		|||
    ("MaxPool", "ImportNodeMaxPool"),
 | 
			
		||||
    ("BatchNormalization", "ImportNodeBatchNormalization"),
 | 
			
		||||
    ("Pad", "ImportNodePad"),
 | 
			
		||||
    ("Reshape", "ImportNodeReshape"),
 | 
			
		||||
    #("Transpose", "ImportNodeTranspose")
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,6 @@
 | 
			
		|||
add_library(builder
 | 
			
		||||
        frontend_dialect_helper.cpp
 | 
			
		||||
        frontend_dialect_helper.hpp
 | 
			
		||||
        frontend_dialect_transformer.cpp
 | 
			
		||||
        frontend_dialect_transformer.hpp
 | 
			
		||||
        op_build_table.inc
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,185 @@
 | 
			
		|||
//===------------------- frontend_dialect_helper.cpp ----------------------===//
 | 
			
		||||
//
 | 
			
		||||
// Copyright 2019 The IBM Research Authors.
 | 
			
		||||
//
 | 
			
		||||
// =============================================================================
 | 
			
		||||
//
 | 
			
		||||
// Helper methods for handling input ONNX models.
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "src/builder/frontend_dialect_helper.hpp"
 | 
			
		||||
 | 
			
		||||
namespace onnf {
 | 
			
		||||
 | 
			
		||||
void replaceAll(std::string &str, const std::string &from,
 | 
			
		||||
                const std::string &to) {
 | 
			
		||||
  if (from.empty())
 | 
			
		||||
    return;
 | 
			
		||||
  size_t start_pos = 0;
 | 
			
		||||
  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
 | 
			
		||||
    str.replace(start_pos, from.length(), to);
 | 
			
		||||
    start_pos += to.length(); // In case 'to' contains 'from', like replacing
 | 
			
		||||
                              // 'x' with 'yx'
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::string legalize_name(std::string name) {
 | 
			
		||||
  std::replace(name.begin(), name.end(), '/', '_');
 | 
			
		||||
  std::replace(name.begin(), name.end(), '-', '_');
 | 
			
		||||
  replaceAll(name, ":", "_colon_");
 | 
			
		||||
  // If tensor name starts with a number, prepend n to make it a legal c++
 | 
			
		||||
  // identifier.
 | 
			
		||||
  if (name.size() > 0 && isdigit(name.at(0)))
 | 
			
		||||
    name.insert(0, 1, 'n');
 | 
			
		||||
  return name;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
mlir::Value OnnxOnnfSymbolMapping::GetTensorByOnnxName(
 | 
			
		||||
    const std::string &name) {
 | 
			
		||||
  assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
 | 
			
		||||
             onnx_name2onnf_tensor.end() &&
 | 
			
		||||
         "Tensor not found");
 | 
			
		||||
  return onnx_name2onnf_tensor.at(legalize_name(name));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void OnnxOnnfSymbolMapping::AddMapping(
 | 
			
		||||
    const std::string &name, mlir::Value tensor) {
 | 
			
		||||
  assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
 | 
			
		||||
         "Tensor already exists.");
 | 
			
		||||
  onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool OnnxOnnfSymbolMapping::ContainKey(std::string name) {
 | 
			
		||||
  return onnx_name2onnf_tensor.count(name) != 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
struct TransformValueToONNXData {
 | 
			
		||||
  static const google::protobuf::RepeatedField<T> data(
 | 
			
		||||
      onnx::TensorProto initializer) {
 | 
			
		||||
    return google::protobuf::RepeatedField<T>();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct TransformValueToONNXData<double> {
 | 
			
		||||
  static const google::protobuf::RepeatedField<double> data(
 | 
			
		||||
      onnx::TensorProto initializer) {
 | 
			
		||||
    return initializer.double_data();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct TransformValueToONNXData<float> {
 | 
			
		||||
  static const google::protobuf::RepeatedField<float> data(
 | 
			
		||||
      onnx::TensorProto initializer) {
 | 
			
		||||
    return initializer.float_data();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct TransformValueToONNXData<int32_t> {
 | 
			
		||||
  static const google::protobuf::RepeatedField<int32_t> data(
 | 
			
		||||
      onnx::TensorProto initializer) {
 | 
			
		||||
    return initializer.int32_data();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct TransformValueToONNXData<int64_t> {
 | 
			
		||||
  static const google::protobuf::RepeatedField<int64_t> data(
 | 
			
		||||
      onnx::TensorProto initializer) {
 | 
			
		||||
    return initializer.int64_data();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Helper method for constructing an array attribute from a model input.
 | 
			
		||||
template <typename T>
 | 
			
		||||
static T* CreateArrayAttribute(onnx::TensorProto initializer, int *size) {
 | 
			
		||||
  if (initializer.raw_data().size()) {
 | 
			
		||||
    // copy & take care of endianness
 | 
			
		||||
    std::vector<char> byteInitializer;
 | 
			
		||||
    std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
 | 
			
		||||
        back_inserter(byteInitializer));
 | 
			
		||||
    *size = initializer.raw_data().size() / sizeof(T);
 | 
			
		||||
    return reinterpret_cast<T*>(&byteInitializer[0]);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // copy, no need to take care of endianness
 | 
			
		||||
  auto data = TransformValueToONNXData<T>::data(initializer);
 | 
			
		||||
  *size = data.size();
 | 
			
		||||
  return &data[0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void InitializedTensorMapping::AddMapping(
 | 
			
		||||
    std::string name, onnx::TensorProto tensor) {
 | 
			
		||||
  assert(nameToInitializedTensor.count(name) == 0 &&
 | 
			
		||||
         "Tensor initializer already mapped.");
 | 
			
		||||
  nameToInitializedTensor.emplace(name, tensor);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
bool InitializedTensorMapping::ContainKey(std::string name) {
 | 
			
		||||
  return nameToInitializedTensor.count(name) != 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
 | 
			
		||||
    mlir::Location loc, mlir::OpBuilder &builder, std::string name) {
 | 
			
		||||
  // Initializer for input.
 | 
			
		||||
  onnx::TensorProto initializer = GetInitializedTensor(name);
 | 
			
		||||
 | 
			
		||||
  // Emit ConstantOp and record the mapping between the input and
 | 
			
		||||
  // the constant value.
 | 
			
		||||
  mlir::ArrayAttr constantArrayAttribute;
 | 
			
		||||
  mlir::Type elementType;
 | 
			
		||||
  int length;
 | 
			
		||||
  switch (initializer.data_type()) {
 | 
			
		||||
    case (onnx::TensorProto::FLOAT): {
 | 
			
		||||
      float *typeArray =
 | 
			
		||||
          CreateArrayAttribute<float>(initializer, &length);
 | 
			
		||||
      std::vector<float> arrayAttrInitializer(
 | 
			
		||||
      	typeArray, typeArray + length);
 | 
			
		||||
      llvm::ArrayRef<float> array(typeArray, length);
 | 
			
		||||
      constantArrayAttribute = builder.getF32ArrayAttr(array);
 | 
			
		||||
      elementType = builder.getF32Type();
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case (onnx::TensorProto::INT32): {
 | 
			
		||||
      int32_t *typeArray =
 | 
			
		||||
          CreateArrayAttribute<int32_t>(initializer, &length);
 | 
			
		||||
      std::vector<int32_t> arrayAttrInitializer(
 | 
			
		||||
      	typeArray, typeArray + length);
 | 
			
		||||
      llvm::ArrayRef<int32_t> array(typeArray, length);
 | 
			
		||||
      constantArrayAttribute = builder.getI32ArrayAttr(array);
 | 
			
		||||
      elementType = builder.getIntegerType(32);
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case (onnx::TensorProto::INT64): {
 | 
			
		||||
      int64_t *typeArray =
 | 
			
		||||
          CreateArrayAttribute<int64_t>(initializer, &length);
 | 
			
		||||
      std::vector<int64_t> arrayAttrInitializer(
 | 
			
		||||
      	typeArray, typeArray + length);
 | 
			
		||||
      llvm::ArrayRef<int64_t> array(typeArray, length);
 | 
			
		||||
      constantArrayAttribute = builder.getI64ArrayAttr(array);
 | 
			
		||||
      elementType = builder.getIntegerType(64);
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Create empty sparse_value attribute.
 | 
			
		||||
  llvm::ArrayRef<int64_t> array;
 | 
			
		||||
  auto sparseValueAttribute = builder.getI64ArrayAttr(array);
 | 
			
		||||
 | 
			
		||||
  // Create value attribute.
 | 
			
		||||
  llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
 | 
			
		||||
      initializer.dims().size());
 | 
			
		||||
  mlir::Type tensorType =
 | 
			
		||||
      mlir::RankedTensorType::get(tensorDims, elementType);
 | 
			
		||||
 | 
			
		||||
  return builder.create<mlir::ONNXConstantOp>(
 | 
			
		||||
      loc, tensorType, sparseValueAttribute,
 | 
			
		||||
      constantArrayAttribute);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace onnf
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,101 @@
 | 
			
		|||
//===------------------- frontend_dialect_helper.hpp ----------------------===//
 | 
			
		||||
//
 | 
			
		||||
// Copyright 2019 The IBM Research Authors.
 | 
			
		||||
//
 | 
			
		||||
// =============================================================================
 | 
			
		||||
//
 | 
			
		||||
// Helper methods for handling input ONNX models.
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <numeric>
 | 
			
		||||
#include <regex>
 | 
			
		||||
#include <tuple>
 | 
			
		||||
 | 
			
		||||
#include "mlir/Analysis/Verifier.h"
 | 
			
		||||
#include "mlir/Dialect/StandardOps/Ops.h"
 | 
			
		||||
#include "mlir/IR/Attributes.h"
 | 
			
		||||
#include "mlir/IR/Builders.h"
 | 
			
		||||
#include "mlir/IR/Function.h"
 | 
			
		||||
#include "mlir/IR/Location.h"
 | 
			
		||||
#include "mlir/IR/Matchers.h"
 | 
			
		||||
#include "mlir/IR/MLIRContext.h"
 | 
			
		||||
#include "mlir/IR/Module.h"
 | 
			
		||||
#include "mlir/IR/PatternMatch.h"
 | 
			
		||||
#include "mlir/IR/StandardTypes.h"
 | 
			
		||||
#include "mlir/IR/Types.h"
 | 
			
		||||
 | 
			
		||||
#include "llvm/ADT/STLExtras.h"
 | 
			
		||||
#include "llvm/ADT/ScopedHashTable.h"
 | 
			
		||||
#include "llvm/Support/raw_ostream.h"
 | 
			
		||||
 | 
			
		||||
#include "src/dialect/onnx/onnx_ops.hpp"
 | 
			
		||||
#include "onnx/onnx_pb.h"
 | 
			
		||||
 | 
			
		||||
namespace onnf {
 | 
			
		||||
 | 
			
		||||
void replaceAll(std::string &str, const std::string &from,
 | 
			
		||||
                const std::string &to);
 | 
			
		||||
 | 
			
		||||
std::string legalize_name(std::string name);
 | 
			
		||||
 | 
			
		||||
struct OnnxOnnfSymbolMapping {
 | 
			
		||||
  /*!
 | 
			
		||||
   *  Get MLIR tensor by onnx tensor name.
 | 
			
		||||
   *  @param name onnx tensor name.
 | 
			
		||||
   *  @return onnf tensor corresponding to `name`.
 | 
			
		||||
   */
 | 
			
		||||
  mlir::Value GetTensorByOnnxName(const std::string &name);
 | 
			
		||||
 | 
			
		||||
  /*!
 | 
			
		||||
   *  Add a new mapping from onnx tensor name to MLIR symbol.
 | 
			
		||||
   *  @param name onnx tensor name.
 | 
			
		||||
   *  @param tensor MLIR Value  pointer.
 | 
			
		||||
   */
 | 
			
		||||
  void AddMapping(const std::string &name, mlir::Value tensor);
 | 
			
		||||
 | 
			
		||||
  bool ContainKey(std::string name);
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /*!
 | 
			
		||||
   *  mapping from onnx tensor names to MLIR tensor.
 | 
			
		||||
   */
 | 
			
		||||
  std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct InitializedTensorMapping {
 | 
			
		||||
  // Add new entry.
 | 
			
		||||
  void AddMapping(std::string name, onnx::TensorProto tensor);
 | 
			
		||||
 | 
			
		||||
  // Check if input is initialized. Not all inputs are, some of the inputs
 | 
			
		||||
  // require input from the user and are not stored inside the ONNX model
 | 
			
		||||
  // itself.
 | 
			
		||||
  bool ContainKey(std::string name);
 | 
			
		||||
 | 
			
		||||
  // Emit constant argument (initialized arguments) as a ConstantOp.
 | 
			
		||||
  // This method will allow operations to use the constant data contained
 | 
			
		||||
  // in an ONNX model as they are being compiled.
 | 
			
		||||
  // This method enables the emission of such constant operation on demand.
 | 
			
		||||
  //
 | 
			
		||||
  // This will allow the propagation of shape information passed in as an
 | 
			
		||||
  // argument to operations such as Reshape and will enable other
 | 
			
		||||
  // optimizations such as constant folding.
 | 
			
		||||
  mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
 | 
			
		||||
  	  mlir::OpBuilder &builder, std::string name);
 | 
			
		||||
 | 
			
		||||
  // Get initialized tensor.
 | 
			
		||||
  onnx::TensorProto& GetInitializedTensor(std::string name) {
 | 
			
		||||
    assert(nameToInitializedTensor.find(name) !=
 | 
			
		||||
               nameToInitializedTensor.end() &&
 | 
			
		||||
           "Tensor initializer not found");
 | 
			
		||||
    return nameToInitializedTensor.at(name);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  // Mapping from ONNX tensor name to InitializedTensor.
 | 
			
		||||
  std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace onnf
 | 
			
		||||
| 
						 | 
				
			
			@ -14,96 +14,20 @@
 | 
			
		|||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include <map>
 | 
			
		||||
#include <numeric>
 | 
			
		||||
#include <regex>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <tuple>
 | 
			
		||||
 | 
			
		||||
// Using backported variant.
 | 
			
		||||
// bstd = backported standard library.
 | 
			
		||||
#include <mpark/variant.hpp>
 | 
			
		||||
namespace bstd = mpark;
 | 
			
		||||
 | 
			
		||||
#include "mlir/Analysis/Verifier.h"
 | 
			
		||||
#include "mlir/Dialect/StandardOps/Ops.h"
 | 
			
		||||
#include "mlir/IR/Attributes.h"
 | 
			
		||||
#include "mlir/IR/Builders.h"
 | 
			
		||||
#include "mlir/IR/Function.h"
 | 
			
		||||
#include "mlir/IR/Location.h"
 | 
			
		||||
#include "mlir/IR/MLIRContext.h"
 | 
			
		||||
#include "mlir/IR/Module.h"
 | 
			
		||||
#include "mlir/IR/StandardTypes.h"
 | 
			
		||||
#include "mlir/IR/Types.h"
 | 
			
		||||
 | 
			
		||||
#include "llvm/ADT/STLExtras.h"
 | 
			
		||||
#include "llvm/ADT/ScopedHashTable.h"
 | 
			
		||||
#include "llvm/Support/raw_ostream.h"
 | 
			
		||||
 | 
			
		||||
#include "src/dialect/onnx/onnx_ops.hpp"
 | 
			
		||||
 | 
			
		||||
#include "frontend_dialect_transformer.hpp"
 | 
			
		||||
 | 
			
		||||
namespace onnf {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
void replaceAll(std::string &str, const std::string &from,
 | 
			
		||||
                const std::string &to) {
 | 
			
		||||
  if (from.empty())
 | 
			
		||||
    return;
 | 
			
		||||
  size_t start_pos = 0;
 | 
			
		||||
  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
 | 
			
		||||
    str.replace(start_pos, from.length(), to);
 | 
			
		||||
    start_pos += to.length(); // In case 'to' contains 'from', like replacing
 | 
			
		||||
                              // 'x' with 'yx'
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::string legalize_name(std::string name) {
 | 
			
		||||
  std::replace(name.begin(), name.end(), '/', '_');
 | 
			
		||||
  std::replace(name.begin(), name.end(), '-', '_');
 | 
			
		||||
  replaceAll(name, ":", "_colon_");
 | 
			
		||||
  // If tensor name starts with a number, prepend n to make it a legal c++
 | 
			
		||||
  // identifier.
 | 
			
		||||
  if (name.size() > 0 && isdigit(name.at(0)))
 | 
			
		||||
    name.insert(0, 1, 'n');
 | 
			
		||||
  return name;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct OnnxOnnfSymbolMapping {
 | 
			
		||||
  /*!
 | 
			
		||||
   *  Get MLIR tensor by onnx tensor name.
 | 
			
		||||
   *  @param name onnx tensor name.
 | 
			
		||||
   *  @return onnf tensor corresponding to `name`.
 | 
			
		||||
/*!
 | 
			
		||||
 *  The list of tensors initialized by the ONNX model.
 | 
			
		||||
 */
 | 
			
		||||
  mlir::Value GetTensorByOnnxName(const std::string &name) {
 | 
			
		||||
    assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
 | 
			
		||||
               onnx_name2onnf_tensor.end() &&
 | 
			
		||||
           "Tensor not found");
 | 
			
		||||
    return onnx_name2onnf_tensor.at(legalize_name(name));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /*!
 | 
			
		||||
   *  Add a new mapping from onnx tensor name to MLIR symbol.
 | 
			
		||||
   *  @param name onnx tensor name.
 | 
			
		||||
   *  @param tensor MLIR Value  pointer.
 | 
			
		||||
   */
 | 
			
		||||
  void AddMapping(const std::string &name, mlir::Value tensor) {
 | 
			
		||||
    assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
 | 
			
		||||
           "Tensor already exists.");
 | 
			
		||||
    onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool ContainKey(std::string name) {
 | 
			
		||||
    return onnx_name2onnf_tensor.count(name) != 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /*!
 | 
			
		||||
   *  mapping from onnx tensor names to MLIR tensor.
 | 
			
		||||
   */
 | 
			
		||||
  std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
 | 
			
		||||
};
 | 
			
		||||
InitializedTensorMapping initializedTensors;
 | 
			
		||||
 | 
			
		||||
class FrontendGenImpl {
 | 
			
		||||
public:
 | 
			
		||||
| 
						 | 
				
			
			@ -167,8 +91,7 @@ private:
 | 
			
		|||
   * @param input onnx input tensor ValueInfoProto.
 | 
			
		||||
   * @param arg_types list of mlir types representing types of graph input.
 | 
			
		||||
   */
 | 
			
		||||
  void ImportInputTensorType(const onnx::ValueInfoProto &input,
 | 
			
		||||
                             llvm::SmallVector<mlir::Type, 4> &arg_types) {
 | 
			
		||||
  mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) {
 | 
			
		||||
    std::vector<int64_t> dims;
 | 
			
		||||
    auto shape_proto = input.type().tensor_type().shape();
 | 
			
		||||
    auto input_tensor_legalized_name = legalize_name(input.name());
 | 
			
		||||
| 
						 | 
				
			
			@ -193,8 +116,7 @@ private:
 | 
			
		|||
        (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
 | 
			
		||||
    mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
 | 
			
		||||
    llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
 | 
			
		||||
    arg_types.emplace_back(
 | 
			
		||||
        mlir::RankedTensorType::get(tensor_dims, elementType));
 | 
			
		||||
    return mlir::RankedTensorType::get(tensor_dims, elementType);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /*!
 | 
			
		||||
| 
						 | 
				
			
			@ -320,16 +242,11 @@ private:
 | 
			
		|||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
 | 
			
		||||
                      int expectedNumResults = -1) {
 | 
			
		||||
  void buildOutputAndOperation(const onnx::NodeProto &node,
 | 
			
		||||
      std::vector<mlir::Value> inputs, int expectedNumOperands,
 | 
			
		||||
      int expectedNumResults) {
 | 
			
		||||
    bool variadicIn = expectedNumOperands == -1;
 | 
			
		||||
    bool variadicOut = expectedNumResults == -1;
 | 
			
		||||
    std::vector<mlir::Value> inputs;
 | 
			
		||||
    for (const auto &item : node.input()) {
 | 
			
		||||
      if (frontend_symbols_.ContainKey(legalize_name(item))) {
 | 
			
		||||
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (!variadicIn)
 | 
			
		||||
      for (auto i = inputs.size(); i < expectedNumOperands; i++)
 | 
			
		||||
| 
						 | 
				
			
			@ -351,6 +268,37 @@ private:
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  void buildOperation(const onnx::NodeProto &node,
 | 
			
		||||
                      int expectedNumOperands = -1,
 | 
			
		||||
                      int expectedNumResults = -1) {
 | 
			
		||||
    std::vector<mlir::Value> inputs;
 | 
			
		||||
    for (const auto &item : node.input())
 | 
			
		||||
      if (frontend_symbols_.ContainKey(legalize_name(item)))
 | 
			
		||||
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
 | 
			
		||||
 | 
			
		||||
    buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
 | 
			
		||||
        expectedNumResults);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
 | 
			
		||||
    std::vector<mlir::Value> inputs;
 | 
			
		||||
    std::string item;
 | 
			
		||||
    for (int i = 0; i < node.input().size(); ++i) {
 | 
			
		||||
      item = node.input()[i];
 | 
			
		||||
      // For the second argument, check if there exists an initializer.
 | 
			
		||||
      if (i == 1 && initializedTensors.ContainKey(legalize_name(item))) {
 | 
			
		||||
          inputs.push_back(
 | 
			
		||||
                initializedTensors.EmitInitializerForInputTensor(
 | 
			
		||||
                    UnknownLoc(), builder_, legalize_name(item)));
 | 
			
		||||
      } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
 | 
			
		||||
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /*!
 | 
			
		||||
   * Special handle for Conv operations.
 | 
			
		||||
   * c++ does not allow template specialization inside a class scope
 | 
			
		||||
| 
						 | 
				
			
			@ -452,38 +400,52 @@ private:
 | 
			
		|||
 | 
			
		||||
  void ImportGraph(const onnx::GraphProto &graph,
 | 
			
		||||
                   const std::string &name = "main_graph") {
 | 
			
		||||
    // Maintain a mapping between the parameter and its initializer.
 | 
			
		||||
    for (auto initializer : graph.initializer()) {
 | 
			
		||||
      auto name = initializer.name();
 | 
			
		||||
      initializedTensors.AddMapping(legalize_name(name), initializer);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // create a function for the graph
 | 
			
		||||
    // TODO:
 | 
			
		||||
    //  * get name and type for the function.
 | 
			
		||||
    //  * maintain a list of the defined graph
 | 
			
		||||
    llvm::SmallVector<mlir::Type, 4> arg_types;
 | 
			
		||||
 | 
			
		||||
    // Import the input tensor types.
 | 
			
		||||
    for (const auto &input : graph.input()) {
 | 
			
		||||
      ImportInputTensorType(input, arg_types);
 | 
			
		||||
    }
 | 
			
		||||
    // Import the input tensor types that are not constant.
 | 
			
		||||
    for (const auto &input : graph.input())
 | 
			
		||||
      arg_types.emplace_back(ImportInputTensorType(input));
 | 
			
		||||
 | 
			
		||||
    // TODO: import the initializer
 | 
			
		||||
    // Create the main function.
 | 
			
		||||
    auto funcType = builder_.getFunctionType(arg_types, {});
 | 
			
		||||
    auto mainFunc =
 | 
			
		||||
        mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
 | 
			
		||||
 | 
			
		||||
    // Emit the entry point operation which specifies the number of user
 | 
			
		||||
    // inputs and outputs.
 | 
			
		||||
    auto entryPoint = mlir::ONNXEntryPointOp::create(
 | 
			
		||||
        UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(),
 | 
			
		||||
        UnknownLoc(), mainFunc,
 | 
			
		||||
        /*numInputs=*/graph.input().size() - graph.initializer().size(),
 | 
			
		||||
        /*numOutputs=*/graph.output().size());
 | 
			
		||||
 | 
			
		||||
    // Get the entru block inside the main function and set the insertion point
 | 
			
		||||
    // to it.
 | 
			
		||||
    auto &entryBlock = *mainFunc.addEntryBlock();
 | 
			
		||||
    builder_.setInsertionPointToStart(&entryBlock);
 | 
			
		||||
 | 
			
		||||
    module_.push_back(mainFunc);
 | 
			
		||||
    module_.push_back(entryPoint);
 | 
			
		||||
 | 
			
		||||
    for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) {
 | 
			
		||||
      ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
 | 
			
		||||
    }
 | 
			
		||||
    // Map graph inputs to entry block arguments.
 | 
			
		||||
    for (int i = 0; i < graph.input().size(); ++i)
 | 
			
		||||
      ImportInputTensorSymbol(
 | 
			
		||||
          graph.input()[i], entryBlock.getArguments()[i]);
 | 
			
		||||
 | 
			
		||||
    // Create a NoneTyped constant to be used for optional operation inputs
 | 
			
		||||
    // which are not used.
 | 
			
		||||
    none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
 | 
			
		||||
        builder_.getUnitAttr());
 | 
			
		||||
 | 
			
		||||
    // Create a NoneTyped constant.
 | 
			
		||||
    none_ =
 | 
			
		||||
        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
 | 
			
		||||
    // Import nodes in the graph.
 | 
			
		||||
    for (const auto &item : graph.node()) {
 | 
			
		||||
      ImportNode(item);
 | 
			
		||||
| 
						 | 
				
			
			@ -509,13 +471,6 @@ private:
 | 
			
		|||
 | 
			
		||||
namespace onnf {
 | 
			
		||||
 | 
			
		||||
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
 | 
			
		||||
  mlir::MLIRContext context;
 | 
			
		||||
  FrontendGenImpl myONNXGen(context);
 | 
			
		||||
  auto module = myONNXGen.ImportONNXModel(model);
 | 
			
		||||
  return module;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ImportFrontendModelFile(std::string model_fname,
 | 
			
		||||
                             mlir::MLIRContext &context,
 | 
			
		||||
                             mlir::OwningModuleRef &module) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,6 +18,8 @@
 | 
			
		|||
 | 
			
		||||
#include "onnx/onnx_pb.h"
 | 
			
		||||
 | 
			
		||||
#include "src/builder/frontend_dialect_helper.hpp"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
class MLIRContext;
 | 
			
		||||
class OwningModuleRef;
 | 
			
		||||
| 
						 | 
				
			
			@ -28,13 +30,6 @@ class OwningModuleRef;
 | 
			
		|||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
namespace onnf {
 | 
			
		||||
/*!
 | 
			
		||||
 *  Import an ONNX model into ONNF's ONNX Dialect.
 | 
			
		||||
 *  @param model onnx model.
 | 
			
		||||
 *  @return MLIR::module generated for the ONNX model.
 | 
			
		||||
 */
 | 
			
		||||
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
 | 
			
		||||
 | 
			
		||||
/*!
 | 
			
		||||
 *  Import an ONNX model file into ONNF's ONNX Dialect.
 | 
			
		||||
 *  @param model_fname file name pointing to the onnx model protobuf.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -224,7 +224,7 @@ if (opName == "ReduceSumSquare")
 | 
			
		|||
if (opName == "Relu")
 | 
			
		||||
  return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 | 
			
		||||
if (opName == "Reshape")
 | 
			
		||||
  return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 | 
			
		||||
  return ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 | 
			
		||||
if (opName == "Resize")
 | 
			
		||||
  return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
 | 
			
		||||
if (opName == "ReverseSequence")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -61,7 +61,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
 | 
			
		|||
          rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
 | 
			
		||||
      SmallVector<Value, 4> DimInfo;
 | 
			
		||||
      for (int i = 0; i < memRefShape.size(); ++i) {
 | 
			
		||||
        Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
 | 
			
		||||
        Value index =
 | 
			
		||||
            emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
 | 
			
		||||
        // Load index from array of indices.
 | 
			
		||||
        Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
 | 
			
		||||
        // If a dimension is zero, the actual dimension value is taken from the
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -202,5 +202,4 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
 | 
			
		|||
                            "FloatAttr constant_value, StringAttr mode">];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#endif // ONNX_OPS
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -656,9 +656,27 @@ void ONNXReshapeOp::inferShapes() {
 | 
			
		|||
  if (outputRank < 0)
 | 
			
		||||
    emitError("Shape tensor must have constant shape");
 | 
			
		||||
 | 
			
		||||
  SmallVector<int64_t, 2> dims;
 | 
			
		||||
  for (int i = 0; i < outputRank; ++i)
 | 
			
		||||
    dims.emplace_back(-1);
 | 
			
		||||
  // Check if second argument of ReshapeOp is a constant.
 | 
			
		||||
  // Get operation that defines the second argument. If this operation is a
 | 
			
		||||
  // `ConstantTensor` operation, the shape of this `Reshape` operation
 | 
			
		||||
  // resides in the `value` attribute of the `ConstantTensor` operation.
 | 
			
		||||
  auto *secondArgDefiningOp = (*getODSOperands(1).begin()).getDefiningOp();
 | 
			
		||||
  auto constantOp =
 | 
			
		||||
      dyn_cast_or_null<mlir::ONNXConstantOp>(secondArgDefiningOp);
 | 
			
		||||
 | 
			
		||||
  SmallVector<int64_t, 2> dims(outputRank, -1);
 | 
			
		||||
  if (constantOp) {
 | 
			
		||||
    ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast<ArrayAttr>();
 | 
			
		||||
 | 
			
		||||
    if (!valueAttribute)
 | 
			
		||||
      emitError("ArrayAttr expected");
 | 
			
		||||
 | 
			
		||||
    if (valueAttribute.getValue().size() != outputRank)
 | 
			
		||||
      emitError("Constant value must have same rank as output");
 | 
			
		||||
 | 
			
		||||
    for (int i=0; i<outputRank; ++i)
 | 
			
		||||
      dims[i] = valueAttribute.getValue()[i].cast<IntegerAttr>().getInt();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  getResult().setType(
 | 
			
		||||
      RankedTensorType::get(dims, inputTensorTy.getElementType()));
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -173,7 +173,7 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
 | 
			
		|||
    OwningRewritePatternList &results, MLIRContext *context) {
 | 
			
		||||
  results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
 | 
			
		||||
}
 | 
			
		||||
/// on the ONNXReduceSumSquareOp.
 | 
			
		||||
/// on the ONNXConvNoBiasOp.
 | 
			
		||||
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
 | 
			
		||||
    OwningRewritePatternList &results, MLIRContext *context) {
 | 
			
		||||
  results.insert<SplitConvOpPattern>(context);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue