[NFC] Set up clang-format Github Action (#119)
* Run clang-format on all source code.
* Add Clang-Format Github Action.
* Apply patch produced by Clang-Format Bot.
* nit.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent 24343177b8
commit 7f2bffb27d
@@ -0,0 +1,40 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Clang-Format Bot
+
+# Controls when the action will run. Triggers the workflow on push or pull request
+# events but only for the master branch
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+    branches: [ master ]
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+  # This workflow contains a single job called "build"
+  build:
+    # The type of runner that the job will run on
+    runs-on: ubuntu-latest
+
+    # Steps represent a sequence of tasks that will be executed as part of the job
+    steps:
+    # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
+    - uses: actions/checkout@v2
+    - name: clang-format lint
+      uses: DoozyX/clang-format-lint-action@v0.5
+      with:
+        # Source folder to check formatting
+        source: ./src
+        # Version of clang-format
+        clangFormatVersion: 9 # optional, default is 9
+
+    # Runs a single command using the runners shell
+    - name: Run a one-line script
+      run: echo Hello, world!
+
+    # Runs a set of commands using the runners shell
+    - name: Run a multi-line script
+      run: |
+        echo Add other actions to build,
+        echo test, and deploy your project.
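Note: the repository's .clang-format file is not included in this diff; the action above only checks formatting under ./src with clang-format 9. A minimal sketch of a config consistent with the reformatting in the hunks below (LLVM style, with continuations indented rather than aligned to the open parenthesis) could look like the following; these options are inferred from the hunks, not taken from this commit:

    # Hypothetical .clang-format sketch -- inferred from the reformatted
    # hunks in this commit, not shipped as part of this diff.
    BasedOnStyle: LLVM
    # DontAlign matches the new continuation style seen below, e.g.
    #   void replaceAll(
    #       std::string &str, const std::string &from, const std::string &to);
    AlignAfterOpenBracket: DontAlign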
@@ -12,8 +12,8 @@
 
 namespace onnx_mlir {
 
-void replaceAll(std::string &str, const std::string &from,
-                const std::string &to) {
+void replaceAll(
+    std::string &str, const std::string &from, const std::string &to) {
   if (from.empty())
     return;
   size_t start_pos = 0;
@@ -121,7 +121,6 @@ void InitializedTensorMapping::AddMapping(
   nameToInitializedTensor.emplace(name, tensor);
 }
 
-
 bool InitializedTensorMapping::ContainKey(std::string name) {
   return nameToInitializedTensor.count(name) != 0;
 }
@@ -132,8 +131,8 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
   onnx::TensorProto initializer = GetInitializedTensor(name);
 
   // Tensor dimensions.
-  llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
-      initializer.dims().size());
+  llvm::ArrayRef<int64_t> tensorDims(
+      initializer.dims().data(), initializer.dims().size());
 
   // Emit ConstantOp and record the mapping between the input and
   // the constant value.
@@ -143,8 +142,7 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
   mlir::ShapedType tensorType;
   switch (initializer.data_type()) {
   case (onnx::TensorProto::FLOAT): {
-      const auto& arrayAttrInitializer =
-          CreateArrayAttribute<float>(initializer);
+    const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
     elementType = builder.getF32Type();
     tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
     constantDenseAttribute = mlir::DenseElementsAttr::get(
@@ -152,7 +150,7 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
     break;
   }
   case (onnx::TensorProto::INT32): {
-      const auto& arrayAttrInitializer =
+    const auto &arrayAttrInitializer =
         CreateArrayAttribute<int32_t>(initializer);
     elementType = builder.getIntegerType(32);
     tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
@@ -161,7 +159,7 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
     break;
   }
   case (onnx::TensorProto::INT64): {
-      const auto& arrayAttrInitializer =
+    const auto &arrayAttrInitializer =
         CreateArrayAttribute<int64_t>(initializer);
     elementType = builder.getIntegerType(64);
     tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
@@ -20,8 +20,8 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
@@ -37,11 +37,12 @@
 #endif
 
 #include "onnx/onnx_pb.h"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
 
 namespace onnx_mlir {
 
-void replaceAll(std::string &str, const std::string &from,
-                const std::string &to);
+void replaceAll(
+    std::string &str, const std::string &from, const std::string &to);
 
 std::string legalize_name(std::string name);
 
@@ -86,13 +87,13 @@ struct InitializedTensorMapping {
   // This will allow the propagation of shape information passed in as an
   // argument to operations such as Reshape and will enable other
   // optimizations such as constant folding.
-  mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
-      mlir::OpBuilder &builder, std::string name);
+  mlir::Value EmitInitializerForInputTensor(
+      mlir::Location loc, mlir::OpBuilder &builder, std::string name);
 
   // Get initialized tensor.
-  onnx::TensorProto& GetInitializedTensor(std::string name) {
-    assert(nameToInitializedTensor.find(name) !=
-               nameToInitializedTensor.end() &&
+  onnx::TensorProto &GetInitializedTensor(std::string name) {
+    assert(
+        nameToInitializedTensor.find(name) != nameToInitializedTensor.end() &&
         "Tensor initializer not found");
     return nameToInitializedTensor.at(name);
   }
@@ -127,8 +127,8 @@ private:
    * @param input onnx input tensor ValueInfoProto.
    * @param symbol mlir input argument.
    */
-  void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
-                               mlir::Value symbol) {
+  void ImportInputTensorSymbol(
+      const onnx::ValueInfoProto &input, mlir::Value symbol) {
    auto input_tensor_legalized_name = legalize_name(input.name());
     assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
            "Found duplicate legalized input tensor names.");
@@ -136,8 +136,7 @@ private:
   }
 
   typedef bstd::variant<int64_t, std::vector<int64_t>, float,
-                        std::vector<float>, std::string,
-                        std::vector<std::string>>
+      std::vector<float>, std::string, std::vector<std::string>>
       AttrValueType;
 
   struct ONNXAttrVisitor {
@@ -213,8 +212,8 @@ private:
     llvm_unreachable("Failed to convert attribute proto to name/value pair");
   }
 
-  std::vector<mlir::NamedAttribute>
-  ImportNodeAttributes(const onnx::NodeProto &node) {
+  std::vector<mlir::NamedAttribute> ImportNodeAttributes(
+      const onnx::NodeProto &node) {
     std::vector<mlir::NamedAttribute> attributes;
     for (int i = 0; i < node.attribute_size(); ++i) {
       auto attr = node.attribute(i);
@@ -281,14 +280,13 @@ private:
     // TODO: Handle optional inputs.
     auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
     for (int i = 0; i < node.output().size(); i++) {
-      frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
-                                   *(op.getODSResults(i).begin()));
+      frontend_symbols_.AddMapping(
+          legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
     }
   }
 
   template <typename T>
-  void buildOperation(const onnx::NodeProto &node,
-                      int expectedNumOperands = -1,
+  void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
       int expectedNumResults = -1) {
     std::vector<mlir::Value> inputs;
     for (const auto &item : node.input())
@@ -299,8 +297,8 @@ private:
         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
       }
 
-    buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
-        expectedNumResults);
+    buildOutputAndOperation<T>(
+        node, inputs, expectedNumOperands, expectedNumResults);
   }
 
   void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
@@ -310,8 +308,7 @@ private:
       item = node.input()[i];
       // For the second argument, check if there exists an initializer.
       if (initializedTensors.ContainKey(legalize_name(item))) {
-          inputs.push_back(
-                initializedTensors.EmitInitializerForInputTensor(
+        inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
             UnknownLoc(), builder_, legalize_name(item)));
       } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@@ -372,7 +369,6 @@ private:
 #if INCLUDE_ONNX_ML == 1
 #include "src/Builder/MLOpBuildTable.inc"
 #endif
-
   }
 
   /*!
@@ -400,8 +396,8 @@ private:
     ret_vals.push_back(tensor_val);
   }
 
-  void ImportGraph(const onnx::GraphProto &graph,
-                   const std::string &name = "main_graph") {
+  void ImportGraph(
+      const onnx::GraphProto &graph, const std::string &name = "main_graph") {
     // Maintain a mapping between the parameter and its initializer.
     for (auto initializer : graph.initializer()) {
       auto name = initializer.name();
@@ -426,8 +422,7 @@ private:
 
     // Emit the entry point operation which specifies the number of user
     // inputs and outputs.
-    auto entryPoint = mlir::ONNXEntryPointOp::create(
-        UnknownLoc(), mainFunc,
+    auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
         /*numInputs=*/graph.input().size() - graph.initializer().size(),
         /*numOutputs=*/graph.output().size());
 
@@ -454,8 +449,8 @@ private:
 
     // Create a NoneTyped constant to be used for optional operation inputs
     // which are not used.
-    none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
-        builder_.getUnitAttr());
+    none_ =
+        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
 
     // Import nodes in the graph.
     for (const auto &item : graph.node()) {
@@ -483,8 +478,7 @@ private:
 namespace onnx_mlir {
 
 void ImportFrontendModelFile(std::string model_fname,
-                             mlir::MLIRContext &context,
-                             mlir::OwningModuleRef &module) {
+    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
   onnx::ModelProto model;
   std::fstream input(model_fname, std::ios::in | std::ios::binary);
 
@@ -36,8 +36,7 @@ namespace onnx_mlir {
  *  @return MLIR::module generated for the ONNX model.
  */
 void ImportFrontendModelFile(std::string model_fname,
-                             mlir::MLIRContext &context,
-                             mlir::OwningModuleRef &module);
+    mlir::MLIRContext &context, mlir::OwningModuleRef &module);
 
 /*!
  *  TODO: Import models into other extension dialects that cover the
@@ -1,4 +1,5 @@
-//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering --------===//
+//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering
+//--------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -21,10 +22,9 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
 public:
   using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(ONNXEntryPointOp op,
-                                     PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
-        op,
+  LogicalResult matchAndRewrite(
+      ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(op,
         op.getAttrOfType<SymbolRefAttr>(
             ONNXEntryPointOp::getEntryPointFuncAttrName()),
         op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
@@ -55,8 +55,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
 
   // We define the specific operations, or dialects, that are legal targets for
   // this lowering.
-  target
-      .addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
+  target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
 
   // TODO: enable this once more ops are supported.
   // We also define the ONNX dialect as Illegal so that the conversion will fail
@@ -81,8 +80,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   // Type conversion for function signatures.
   // Call MLIR FuncOp signature conversion when result type is
   // a ranked tensor.
-  populateFuncOpTypeConversionPattern(patterns, &getContext(),
-                                      tensor_to_memref_converter);
+  populateFuncOpTypeConversionPattern(
+      patterns, &getContext(), tensor_to_memref_converter);
 
   // Frontend operation lowering.
   // Math
@@ -119,5 +118,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
   return std::make_unique<FrontendToKrnlLoweringPass>();
 }
 
-static PassRegistration<FrontendToKrnlLoweringPass>
-    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
+static PassRegistration<FrontendToKrnlLoweringPass> pass(
+    "lower-frontend", "Lower frontend ops to Krnl dialect.");
@@ -499,8 +499,7 @@ template <typename ElementwiseUnaryOp>
 struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
   ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
       : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     // TODO: Check that the types are valid.
     // An element-wise unary operation must have all operands and the result of
@@ -566,8 +565,7 @@ template <typename ElementwiseVariadicOp>
 struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
   ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
       : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     // TODO: Check that the types are valid.
     // An element-wise variadic operation must have all operands and the result
@@ -1,4 +1,5 @@
-//===----------------- Gemm.cpp - Lowering Gemm Op -------------------------===//
+//===----------------- Gemm.cpp - Lowering Gemm Op
+//-------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -17,8 +18,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   ONNXGemmOpLowering(MLIRContext *ctx)
       : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
     bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
@@ -32,11 +32,9 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 
     auto memRefType = convertToMemRefType(*op->result_type_begin());
 
-    auto alphaAttr =
-        FloatAttr::get(memRefType.getElementType(),
+    auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
         llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
-    auto betaAttr =
-        FloatAttr::get(memRefType.getElementType(),
+    auto betaAttr = FloatAttr::get(memRefType.getElementType(),
         llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
     auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@@ -101,8 +99,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     optimizedReductionLoops.reserve(1);
     reductionLoops.push_back(originalLoops[2]);
     optimizedReductionLoops.push_back(optimizedLoops[2]);
-    KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
-                                         optimizedReductionLoops);
+    KrnlIterateOperandPack reductionPack(
+        rewriter, reductionLoops, optimizedReductionLoops);
     // Induction variable for the reduction dimension
     // Try to find and use a static value from A or B first.
     // If it failed then use a dynamic value.
@@ -167,8 +165,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
     auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
     if (hasBias) {
-      auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
-                                                broadcastedDimInfo);
+      auto loopCIVs = getLoopIVsForBroadcasting(
+          loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
       auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
       auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
       auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
@@ -214,7 +212,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   }
 };
 
-void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
-                                       MLIRContext *ctx) {
+void populateLoweringONNXGemmOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
   patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
 }
@@ -16,8 +16,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
   ONNXMatMulOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
 
@@ -119,8 +118,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       // Define loops for batch dimensions.
       std::vector<Value> originalLoops;
       std::vector<Value> optimizedLoops;
-      Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
-            optimizedLoops, memRefShape.size());
+      Block *optimizationBlock = defineLoops(
+          rewriter, loc, originalLoops, optimizedLoops, memRefShape.size());
 
       // Outer KrnlIterateOp
       SmallVector<Value, 4> loopBatchIVs;
@@ -139,8 +138,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
           outerLoops.push_back(originalLoops[i]);
           optimizedOuterLoops.push_back(optimizedLoops[i]);
         }
-        KrnlIterateOperandPack outerPack(rewriter, outerLoops,
-                                         optimizedOuterLoops);
+        KrnlIterateOperandPack outerPack(
+            rewriter, outerLoops, optimizedOuterLoops);
         for (int i = 0; i < batchAxes.size(); ++i) {
           addDimensionToPack(rewriter, loc, outerPack, alloc, i);
         }
@@ -176,11 +175,11 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
           optimizedMatmulLoops.emplace_back(
               optimizedLoops[memRefShape.size() - i]);
         }
-        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
-                                          optimizedMatmulLoops);
+        KrnlIterateOperandPack matmulPack(
+            rewriter, matmulLoops, optimizedMatmulLoops);
         for (int i = 2; i > 0; --i) {
-          addDimensionToPack(rewriter, loc, matmulPack, alloc,
-                             memRefShape.size() - i);
+          addDimensionToPack(
+              rewriter, loc, matmulPack, alloc, memRefShape.size() - i);
         }
         matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
       } else {
@@ -190,10 +189,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
         matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
         optimizedMatmulLoops.emplace_back(
             optimizedLoops[memRefShape.size() - 1]);
-        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
-                                          optimizedMatmulLoops);
-        addDimensionToPack(rewriter, loc, matmulPack, alloc,
-                           memRefShape.size() - 1);
+        KrnlIterateOperandPack matmulPack(
+            rewriter, matmulLoops, optimizedMatmulLoops);
+        addDimensionToPack(
+            rewriter, loc, matmulPack, alloc, memRefShape.size() - 1);
         matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
       }
 
@@ -230,8 +229,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       std::vector<Value> optimizedReduceLoops;
       Block *optimizationReduceBlock =
           defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
-      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
-                                        optimizedReduceLoops);
+      KrnlIterateOperandPack reducePack(
+          rewriter, reduceLoops, optimizedReduceLoops);
       addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
       auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
 
@@ -292,8 +291,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       std::vector<Value> optimizedReduceLoops;
      Block *optimizationReduceBlock =
           defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
-      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
-                                        optimizedReduceLoops);
+      KrnlIterateOperandPack reducePack(
+          rewriter, reduceLoops, optimizedReduceLoops);
       addDimensionToPack(rewriter, loc, reducePack, A, 0);
       auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
 
@@ -102,8 +102,7 @@ struct ONNXReductionOpLowering : public ConversionPattern {
   ONNXReductionOpLowering(MLIRContext *ctx)
       : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     /*
      * Condition: reduction function must be associative and commutative.
@@ -1,4 +1,5 @@
-//===--------------- Conv.cpp - Lowering Convolution Op --------------------===//
+//===--------------- Conv.cpp - Lowering Convolution Op
+//--------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -175,14 +176,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
 
         // Emit the bias, if needed.
         if (hasBias) {
-          auto loadResult =
-              rewriter.create<LoadOp>(loc, alloc, resultIndices);
+          auto loadResult = rewriter.create<LoadOp>(loc, alloc, resultIndices);
           SmallVector<Value, 4> biasIndices;
           biasIndices.emplace_back(kernel);
-          auto loadBias =
-              rewriter.create<LoadOp>(loc, biasOperand, kernel);
-          auto resultWithBias = rewriter.create<MulFOp>(
-            loc, loadResult, loadBias);
+          auto loadBias = rewriter.create<LoadOp>(loc, biasOperand, kernel);
+          auto resultWithBias =
+              rewriter.create<MulFOp>(loc, loadResult, loadBias);
           // Store initializer value into output location.
           rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
         }
@@ -459,7 +459,8 @@ struct ONNXPoolOpLowering : public ConversionPattern {
         poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
 
         // poolStartMap and poolEndMap
-        poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
+        poolStartMap =
+            AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
         poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
       }
 
@@ -36,9 +36,7 @@ MemRefType convertToMemRefType(Type type) {
 
 /// Insert an allocation and deallocation for the given MemRefType.
 Value insertAllocAndDealloc(MemRefType type, Location loc,
-                                   PatternRewriter &rewriter,
-                                   bool insertDealloc,
-                                   ArrayRef<Value> operands) {
+    PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands) {
   // Put together alloc operands for any dynamic dimensions of the memref.
   AllocOp alloc;
   if (!operands.empty()) {
@@ -64,10 +62,10 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
           auto operandDim =
               rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
           if (maxDim) {
-            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
-                                                        operandDim, maxDim);
-            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
-                                               maxDim);
+            auto maxCondition = rewriter.create<CmpIOp>(
+                loc, CmpIPredicate::sgt, operandDim, maxDim);
+            maxDim = rewriter.create<SelectOp>(
+                loc, maxCondition, operandDim, maxDim);
           } else {
             maxDim = operandDim;
           }
@@ -122,8 +120,8 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
 // Create a mapping from result type's dimensions to input type's dimensions,
 // given that the result type is the result of a reduction op over the input
 // type.
-std::map<int64_t, int64_t>
-getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
+std::map<int64_t, int64_t> getReductionMapping(
+    MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
   std::map<int64_t, int64_t> OutInDimMap;
   int64_t rank = inputTy.getRank();
 
@@ -152,9 +150,8 @@ getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
 
 // Add bounds associated with the op operand to the KRNL iteration pack.
 // Dynamic dimenions are supported.
-void addDimensionToPack(ConversionPatternRewriter &rewriter,
-                               Location loc, KrnlIterateOperandPack &pack,
-                               Value operand, int index) {
+void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
+    KrnlIterateOperandPack &pack, Value operand, int index) {
   auto shape = operand.getType().cast<MemRefType>().getShape();
   if (shape[index] < 0) {
     pack.pushConstantBound(0);
@@ -168,10 +165,9 @@ void addDimensionToPack(ConversionPatternRewriter &rewriter,
 
 // Function that defines the KRNL dialect loops and their respective
 // optimized version.
-KrnlOptimizeLoopsOp
-emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
-                   std::vector<Value> &loops,
-                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
+KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
+    Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
+    int64_t numLoops) {
   // Define loops.
   auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
   loops.reserve(numLoops);
@@ -190,8 +186,7 @@ emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
 // Function that emits the loops and their optimized version.
 // The function returns a reference to the inner optimization block.
 Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
-                          std::vector<Value> &loops,
-                          std::vector<Value> &optimizedLoops,
+    std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
     int64_t numLoops) {
   KrnlOptimizeLoopsOp optimizedLoopsOp =
       emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
@@ -201,10 +196,9 @@ Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
 // Function which emits a basic set of loops and optimized loops
 // for a given operation argument. A reference to the loop optimization
 // block is returned in the last argument of the function.
-void emitKrnlLoopsAndIterationForOperand(
-    ConversionPatternRewriter &rewriter, Location loc, Value operand,
-    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
-    KrnlIterateOp &iterateOp) {
+void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
+    Location loc, Value operand, std::vector<Value> &originalLoops,
+    KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
   // Operand shape.
   auto shape = operand.getType().cast<MemRefType>().getShape();
 
@@ -240,9 +234,9 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
 
 // Get run-time dimension information for unknown dimensions used for
 // broadcasting.
-std::map<int, std::map<int, Value>>
-getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
-                      MemRefType memRefType, ArrayRef<Value> operands) {
+std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
+    ConversionPatternRewriter &rewriter, MemRefType memRefType,
+    ArrayRef<Value> operands) {
   auto memRefShape = memRefType.getShape();
   int64_t rank = memRefShape.size();
   // For unknown dimensions, we need to get dimension values at runtime in
@@ -286,9 +280,8 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
 
 // Extract induction variables that are used for broadcasting values of a
 // given operand.
-std::vector<Value>
-getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
-                          ArrayRef<Value> loopIVs, Value operand,
+std::vector<Value> getLoopIVsForBroadcasting(Location loc,
+    ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
     std::map<int, Value> broadcastedDims) {
   // `operand` must has a ranked type. This should have been checked by the
   // shape inference pass.
@@ -310,8 +303,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
       // If its value is 1, it is broadcasted dimension.
       // Otherwise, non-broadcasted dimension.
       auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
-                                           loopIVs[loopIdx]);
+      auto idx = rewriter.create<SelectOp>(
+          loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]);
       newLoopIVs.insert(newLoopIVs.begin(), idx);
     } else {
       // Non-broadcasted dimension
@@ -30,7 +30,7 @@ struct ONNXConcatOpLowering : public ConversionPattern {
     auto memRefType = convertToMemRefType(*op->result_type_begin());
     auto resultShape = memRefType.getShape();
     auto rank = resultShape.size();
-    assert((axis >=0 && axis < rank) && "Concat axis out of bounds");
+    assert((axis >= 0 && axis < rank) && "Concat axis out of bounds");
 
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
@@ -34,12 +34,11 @@ struct ONNXConstantOpLowering : public ConversionPattern {
     // Shape based computations.
     auto shape = memRefType.getShape();
     int64_t numElements = 1;
-    for (int i=0; i<shape.size(); ++i)
+    for (int i = 0; i < shape.size(); ++i)
       numElements *= shape[i];
 
     // Emit the constant global in Krnl dialect.
-    auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
-        memRefType,
+    auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
         rewriter.getI64ArrayAttr(shape),
         rewriter.getStringAttr("constant_" + std::to_string(constantID)),
         constantOp.value().getValue());
@@ -24,14 +24,14 @@ enum Kinds {
 }
 
 class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
- public:
+public:
   using Base::Base;
 
   // Support type inquiry through isa, cast and dyn_cast.
   static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; }
 
   // Get a unique instance of Loop type.
-  static LoopType get(mlir::MLIRContext* context) {
+  static LoopType get(mlir::MLIRContext *context) {
     return Base::get(context, KrnlTypes::Loop);
   }
 };
@@ -39,7 +39,6 @@ MLONNXOpsDialect::MLONNXOpsDialect(mlir::MLIRContext *ctx)
       >();
 }
 
-
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
@@ -19,14 +19,14 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 
-#include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
+#include "src/Interface/ShapeInferenceInterface.hpp"
 
 namespace mlir {
 
 class MLONNXOpsDialect : public Dialect {
- public:
-  MLONNXOpsDialect(MLIRContext* context);
+public:
+  MLONNXOpsDialect(MLIRContext *context);
 
   /// Provide a utility accessor to the dialect namespace. This is used by
   /// several utilities for casting between dialects.
@@ -19,14 +19,14 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 
-#include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
+#include "src/Interface/ShapeInferenceInterface.hpp"
 
 namespace mlir {
 
 class ONNXOpsDialect : public Dialect {
- public:
-  ONNXOpsDialect(MLIRContext* context);
+public:
+  ONNXOpsDialect(MLIRContext *context);
 
   /// Provide a utility accessor to the dialect namespace. This is used by
   /// several utilities for casting between dialects.
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
 
 using namespace mlir;
@@ -20,4 +19,3 @@ using namespace mlir;
 //===----------------------------------------------------------------------===//
 
 #include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc"
-
@@ -46,10 +46,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
 void EmitLLVMBitCode(
     const mlir::OwningModuleRef &module, string outputFilename) {
   error_code error;
-  llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error,
-                                           llvm::sys::fs::F_None);
-  llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
-                           moduleBitcodeStream);
+  llvm::raw_fd_ostream moduleBitcodeStream(
+      outputFilename, error, llvm::sys::fs::F_None);
+  llvm::WriteBitcodeToFile(
+      *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
   moduleBitcodeStream.flush();
 }
 
@@ -99,7 +99,6 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
   assert(inputIsONNX != inputIsMLIR &&
          "Either ONNX model or MLIR file needs to be provided.");
 
-
   if (inputIsONNX) {
     ImportFrontendModelFile(inputFilename, context, module);
   } else {
@@ -119,7 +118,7 @@ void outputCode(
   module->dump();
   fflush(stderr);
   // set modified stderr as original stderr
-  _dup2(stderrOrigin, _fileno( stderr ));
+  _dup2(stderrOrigin, _fileno(stderr));
 #else
   if (fork() == 0) {
     freopen(tempFilename.c_str(), "w", stderr);
@@ -28,9 +28,9 @@
 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
 #include "mlir/ExecutionEngine/OptUtils.h"
-#include "mlir/InitAllDialects.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
+#include "mlir/InitAllDialects.h"
 #include "mlir/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -66,8 +66,7 @@ void processInputFile(std::string inputFilename,
     mlir::OwningModuleRef &module);
 
 void outputCode(
-    mlir::OwningModuleRef &module, std::string filename,
-    std::string extension);
+    mlir::OwningModuleRef &module, std::string filename, std::string extension);
 
 void emitOutputFiles(std::string outputBaseName,
     EmissionTargetType emissionTarget, mlir::MLIRContext &context,
@@ -33,8 +33,8 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
   return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
 }
 
-void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
-                  DynMemRef *tensor) {
+void setDynMemRef(
+    OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
   if (tensorDict->orderedNames.size() <= idx)
     tensorDict->orderedNames.resize(idx + 1);
 
@@ -39,7 +39,6 @@ typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
 extern "C" {
 #endif
 
-
 // Get number of dynamic memrefs in OrderedDynMemRefDict dict.
 int numDynMemRefs(OrderedDynMemRefDict *dict);
 
@@ -53,8 +52,8 @@ DynMemRef *createDynMemRef(int rank);
 DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
 
 // Set the i-th dynmemref in orderedDict to be dynMemRef.
-void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
-                  DynMemRef *dynMemRef);
+void setDynMemRef(
+    OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
 
 // Get data pointer from dynMemRef.
 void *getData(DynMemRef *dynMemRef);
@@ -1,14 +1,14 @@
 #include "Runtime.hpp"
 
-ExecutionSession::ExecutionSession(std::string sharedLibPath,
-                                   std::string entryPointName) {
+ExecutionSession::ExecutionSession(
+    std::string sharedLibPath, std::string entryPointName) {
   _sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
   _entryPointFunc =
       (entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
 }
 
-std::vector<py::array>
-ExecutionSession::run(std::vector<py::array> inputsPyArray) {
+std::vector<py::array> ExecutionSession::run(
+    std::vector<py::array> inputsPyArray) {
   assert(_entryPointFunc && "entry point not loaded");
   auto *wrappedInput = createOrderedDynMemRefDict();
   int inputIdx = 0;
@@ -40,8 +40,8 @@ ExecutionSession::run(std::vector<py::array> inputsPyArray) {
   auto *wrappedOutput = _entryPointFunc(wrappedInput);
   for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
     auto *dynMemRef = getDynMemRef(wrappedOutput, i);
-    auto shape = std::vector<int64_t>(dynMemRef->sizes,
-                                      dynMemRef->sizes + dynMemRef->rank);
+    auto shape = std::vector<int64_t>(
+        dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
     outputPyArrays.emplace_back(
         py::array(py::dtype("float32"), shape, dynMemRef->data));
   }
@@ -144,9 +144,9 @@ public:
 
       assert(krnlGlobalOp.value().hasValue() &&
              "Krnl Global must always have a value");
-      global = rewriter.create<LLVM::GlobalOp>(loc,
-          llvmGlobalType, /*isConstant=*/true,
-          LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
+      global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
          /*isConstant=*/true, LLVM::Linkage::Internal, name,
          krnlGlobalOp.value().getValue());
     }
 
     // Some frequently used types.
@@ -12,8 +12,8 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 
-#include <numeric>
 #include "src/Dialect/ONNX/ONNXOps.hpp"
+#include <numeric>
 
 using namespace mlir;
 
@@ -25,22 +25,22 @@ namespace {
 /// Register optimization patterns as "canonicalization" patterns
 /// on the ONNXMatMultOp.
 void ONNXAddOp::getCanonicalizationPatterns(
-    OwningRewritePatternList& results, MLIRContext* context) {
+    OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<MulAddToGemmOptPattern>(context);
 }
 
 void ONNXGemmOp::getCanonicalizationPatterns(
-        OwningRewritePatternList& results, MLIRContext* context) {
+    OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<FuseGemmFollowedByAddition>(context);
 }
 /// on the ONNXIdentityOp.
 void ONNXIdentityOp::getCanonicalizationPatterns(
-    OwningRewritePatternList& results, MLIRContext* context) {
+    OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<IdentityEliminationPattern>(context);
 }
 
-///on the ONNXPadConstantValueOp.
+/// on the ONNXPadConstantValueOp.
 void ONNXPadConstantValueOp::getCanonicalizationPatterns(
-    OwningRewritePatternList& result, MLIRContext* context) {
+    OwningRewritePatternList &result, MLIRContext *context) {
   result.insert<ConstantPadPattern>(context);
 }
@@ -27,7 +27,8 @@ namespace {
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Transform/ONNX/ONNXDecompose.inc"
 
-struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
+struct DecomposeONNXToONNXPass
+    : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
   void runOnFunction() final;
 };
 } // end anonymous namespace.
@@ -9,10 +9,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/StandardTypes.h"
 
 #include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Pass/Passes.hpp"
@@ -25,7 +25,8 @@ namespace {
  *  candidate operations and propagating the shape information until the list
  *  of operations is empty [credit MLIR authors].
  */
-class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
+class ShapeInferencePass
+    : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
 public:
   void runOnFunction() override {
     auto f = getFunction();
@@ -63,8 +64,7 @@ public:
 
     if (auto terminator_op = f.getBody().back().getTerminator()) {
       auto results = terminator_op->getOperandTypes();
-      f.setType(FunctionType::get(
-          f.getType().getInputs(),
+      f.setType(FunctionType::get(f.getType().getInputs(),
           std::vector<Type>(results.begin(), results.end()), f.getContext()));
     }
   }
@@ -146,5 +146,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
   return std::make_unique<ShapeInferencePass>();
 }
 
-static PassRegistration<ShapeInferencePass>
-    pass("shape-inference", "Shape inference for frontend dialects.");
+static PassRegistration<ShapeInferencePass> pass(
+    "shape-inference", "Shape inference for frontend dialects.");
src/main.cpp | 20 ++++++++++----------
@@ -14,10 +14,10 @@ using namespace onnx_mlir;
 int main(int argc, char *argv[]) {
   registerDialects();
 
-  llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options",
-                                       "These are frontend options.");
-  llvm::cl::opt<string> inputFilename(
-      llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
+  llvm::cl::OptionCategory OnnxMlirOptions(
+      "ONNX MLIR Options", "These are frontend options.");
+  llvm::cl::opt<string> inputFilename(llvm::cl::Positional,
+      llvm::cl::desc("<input file>"), llvm::cl::init("-"),
       llvm::cl::cat(OnnxMlirOptions));
 
   llvm::cl::opt<EmissionTargetType> emissionTarget(
@@ -26,18 +26,18 @@ int main(int argc, char *argv[]) {
           clEnumVal(EmitONNXBasic,
               "Ingest ONNX and emit the basic ONNX operations without"
               "inferred shapes."),
-          clEnumVal(EmitONNXIR,
-                    "Ingest ONNX and emit corresponding ONNX dialect."),
-          clEnumVal(EmitMLIR,
-                    "Lower model to MLIR built-in transformation dialect."),
+          clEnumVal(
+              EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."),
+          clEnumVal(
+              EmitMLIR, "Lower model to MLIR built-in transformation dialect."),
           clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
           clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) "
                                 "LLVM bitcode for model.")),
       llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions));
 
   llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
-  llvm::cl::ParseCommandLineOptions(argc, argv,
-                                    "ONNX MLIR modular optimizer driver\n");
+  llvm::cl::ParseCommandLineOptions(
+      argc, argv, "ONNX MLIR modular optimizer driver\n");
 
   mlir::MLIRContext context;
   mlir::OwningModuleRef module;