[NFC] Set up clang-format Github Action (#119)
* Run clang-format on all source code.
* Add Clang-Format Github Action.
* Apply patch produced by Clang-Format Bot.
* nit.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
parent 24343177b8
commit 7f2bffb27d
@@ -0,0 +1,40 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Clang-Format Bot
+
+# Controls when the action will run. Triggers the workflow on push or pull request
+# events but only for the master branch
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+    branches: [ master ]
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+  # This workflow contains a single job called "build"
+  build:
+    # The type of runner that the job will run on
+    runs-on: ubuntu-latest
+
+    # Steps represent a sequence of tasks that will be executed as part of the job
+    steps:
+    # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
+    - uses: actions/checkout@v2
+    - name: clang-format lint
+      uses: DoozyX/clang-format-lint-action@v0.5
+      with:
+        # Source folder to check formatting
+        source: ./src
+        # Version of clang-format
+        clangFormatVersion: 9 # optional, default is 9
+
+    # Runs a single command using the runners shell
+    - name: Run a one-line script
+      run: echo Hello, world!
+
+    # Runs a set of commands using the runners shell
+    - name: Run a multi-line script
+      run: |
+        echo Add other actions to build,
+        echo test, and deploy your project.
@@ -12,8 +12,8 @@
 
 namespace onnx_mlir {
 
-void replaceAll(std::string &str, const std::string &from,
-                const std::string &to) {
+void replaceAll(
+    std::string &str, const std::string &from, const std::string &to) {
   if (from.empty())
     return;
   size_t start_pos = 0;
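The hunk above cuts off mid-function; for orientation, the remainder of replaceAll is presumably the standard find-and-replace loop. A minimal sketch, not text taken from this commit:

#include <string>

void replaceAll(
    std::string &str, const std::string &from, const std::string &to) {
  if (from.empty())
    return;
  size_t start_pos = 0;
  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
    str.replace(start_pos, from.length(), to);
    // Advance past the replacement so a `to` that contains `from` cannot loop.
    start_pos += to.length();
  }
}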
@@ -121,7 +121,6 @@ void InitializedTensorMapping::AddMapping(
   nameToInitializedTensor.emplace(name, tensor);
 }
 
-
 bool InitializedTensorMapping::ContainKey(std::string name) {
   return nameToInitializedTensor.count(name) != 0;
 }

@@ -132,8 +131,8 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
   onnx::TensorProto initializer = GetInitializedTensor(name);
 
   // Tensor dimensions.
-  llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
-      initializer.dims().size());
+  llvm::ArrayRef<int64_t> tensorDims(
+      initializer.dims().data(), initializer.dims().size());
 
   // Emit ConstantOp and record the mapping between the input and
   // the constant value.

@@ -142,33 +141,32 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
   mlir::Type elementType;
   mlir::ShapedType tensorType;
   switch (initializer.data_type()) {
-    case (onnx::TensorProto::FLOAT): {
-      const auto& arrayAttrInitializer =
-          CreateArrayAttribute<float>(initializer);
-      elementType = builder.getF32Type();
-      tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
-      constantDenseAttribute = mlir::DenseElementsAttr::get(
-          tensorType, llvm::makeArrayRef(arrayAttrInitializer));
-      break;
-    }
-    case (onnx::TensorProto::INT32): {
-      const auto& arrayAttrInitializer =
-          CreateArrayAttribute<int32_t>(initializer);
-      elementType = builder.getIntegerType(32);
-      tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
-      constantDenseAttribute = mlir::DenseElementsAttr::get(
-          tensorType, llvm::makeArrayRef(arrayAttrInitializer));
-      break;
-    }
-    case (onnx::TensorProto::INT64): {
-      const auto& arrayAttrInitializer =
-          CreateArrayAttribute<int64_t>(initializer);
-      elementType = builder.getIntegerType(64);
-      tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
-      constantDenseAttribute = mlir::DenseElementsAttr::get(
-          tensorType, llvm::makeArrayRef(arrayAttrInitializer));
-      break;
-    }
+  case (onnx::TensorProto::FLOAT): {
+    const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
+    elementType = builder.getF32Type();
+    tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
+    constantDenseAttribute = mlir::DenseElementsAttr::get(
+        tensorType, llvm::makeArrayRef(arrayAttrInitializer));
+    break;
+  }
+  case (onnx::TensorProto::INT32): {
+    const auto &arrayAttrInitializer =
+        CreateArrayAttribute<int32_t>(initializer);
+    elementType = builder.getIntegerType(32);
+    tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
+    constantDenseAttribute = mlir::DenseElementsAttr::get(
+        tensorType, llvm::makeArrayRef(arrayAttrInitializer));
+    break;
+  }
+  case (onnx::TensorProto::INT64): {
+    const auto &arrayAttrInitializer =
+        CreateArrayAttribute<int64_t>(initializer);
+    elementType = builder.getIntegerType(64);
+    tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
+    constantDenseAttribute = mlir::DenseElementsAttr::get(
+        tensorType, llvm::makeArrayRef(arrayAttrInitializer));
+    break;
+  }
   }
 
   // Create ConstantOp for dense array.
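The three cases above differ only in the element type passed to CreateArrayAttribute<T>, whose definition is not part of this diff. A hypothetical sketch of such a helper, assuming the initializer payload sits in the packed little-endian raw_data field (the typed repeated fields such as float_data would need an extra branch, omitted here):

#include <cstring>
#include <vector>
#include "onnx/onnx_pb.h"

template <typename T>
std::vector<T> CreateArrayAttribute(const onnx::TensorProto &initializer) {
  // Interpret the raw byte buffer as a flat array of T.
  std::vector<T> result(initializer.raw_data().size() / sizeof(T));
  std::memcpy(result.data(), initializer.raw_data().data(),
      initializer.raw_data().size());
  return result;
}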
@@ -20,8 +20,8 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"

@@ -37,11 +37,12 @@
 #endif
 
 #include "onnx/onnx_pb.h"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
 
 namespace onnx_mlir {
 
-void replaceAll(std::string &str, const std::string &from,
-                const std::string &to);
+void replaceAll(
+    std::string &str, const std::string &from, const std::string &to);
 
 std::string legalize_name(std::string name);
 

@@ -86,14 +87,14 @@ struct InitializedTensorMapping {
   // This will allow the propagation of shape information passed in as an
   // argument to operations such as Reshape and will enable other
   // optimizations such as constant folding.
-  mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
-      mlir::OpBuilder &builder, std::string name);
+  mlir::Value EmitInitializerForInputTensor(
+      mlir::Location loc, mlir::OpBuilder &builder, std::string name);
 
   // Get initialized tensor.
-  onnx::TensorProto& GetInitializedTensor(std::string name) {
-    assert(nameToInitializedTensor.find(name) !=
-               nameToInitializedTensor.end() &&
-           "Tensor initializer not found");
+  onnx::TensorProto &GetInitializedTensor(std::string name) {
+    assert(
+        nameToInitializedTensor.find(name) != nameToInitializedTensor.end() &&
+        "Tensor initializer not found");
     return nameToInitializedTensor.at(name);
   }
 
@@ -127,8 +127,8 @@ private:
    * @param input onnx input tensor ValueInfoProto.
    * @param symbol mlir input argument.
    */
-  void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
-                               mlir::Value symbol) {
+  void ImportInputTensorSymbol(
+      const onnx::ValueInfoProto &input, mlir::Value symbol) {
     auto input_tensor_legalized_name = legalize_name(input.name());
     assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
            "Found duplicate legalized input tensor names.");

@@ -136,8 +136,7 @@ private:
   }
 
   typedef bstd::variant<int64_t, std::vector<int64_t>, float,
-                        std::vector<float>, std::string,
-                        std::vector<std::string>>
+      std::vector<float>, std::string, std::vector<std::string>>
       AttrValueType;
 
   struct ONNXAttrVisitor {
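The AttrValueType variant above pairs with the ONNXAttrVisitor that follows: each ONNX attribute proto is converted to one variant alternative, and a visitor maps the active alternative to an MLIR attribute. A minimal illustration of the pattern, assuming bstd::variant is a std::variant backport (names here are illustrative only, not from this commit):

#include <string>
#include <variant>
#include <vector>

using AttrValueType = std::variant<int64_t, std::vector<int64_t>, float,
    std::vector<float>, std::string, std::vector<std::string>>;

struct KindNameVisitor {
  const char *operator()(int64_t) const { return "int"; }
  const char *operator()(const std::vector<int64_t> &) const { return "ints"; }
  const char *operator()(float) const { return "float"; }
  const char *operator()(const std::vector<float> &) const { return "floats"; }
  const char *operator()(const std::string &) const { return "string"; }
  const char *operator()(const std::vector<std::string> &) const {
    return "strings";
  }
};

// std::visit dispatches to the overload matching the active alternative:
// const char *kind = std::visit(KindNameVisitor{}, value);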
@@ -213,8 +212,8 @@ private:
     llvm_unreachable("Failed to convert attribute proto to name/value pair");
   }
 
-  std::vector<mlir::NamedAttribute>
-  ImportNodeAttributes(const onnx::NodeProto &node) {
+  std::vector<mlir::NamedAttribute> ImportNodeAttributes(
+      const onnx::NodeProto &node) {
     std::vector<mlir::NamedAttribute> attributes;
     for (int i = 0; i < node.attribute_size(); ++i) {
       auto attr = node.attribute(i);
@@ -281,26 +280,25 @@ private:
     // TODO: Handle optional inputs.
     auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
     for (int i = 0; i < node.output().size(); i++) {
-      frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
-                                   *(op.getODSResults(i).begin()));
+      frontend_symbols_.AddMapping(
+          legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
     }
   }
 
   template <typename T>
-  void buildOperation(const onnx::NodeProto &node,
-                      int expectedNumOperands = -1,
-                      int expectedNumResults = -1) {
+  void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
+      int expectedNumResults = -1) {
     std::vector<mlir::Value> inputs;
     for (const auto &item : node.input())
       if (initializedTensors.ContainKey(legalize_name(item))) {
-        inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
-                              UnknownLoc(), builder_, legalize_name(item)));
+        inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
+            UnknownLoc(), builder_, legalize_name(item)));
       } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
       }
 
-    buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
-        expectedNumResults);
+    buildOutputAndOperation<T>(
+        node, inputs, expectedNumOperands, expectedNumResults);
   }
 
   void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {

@@ -310,9 +308,8 @@ private:
       item = node.input()[i];
       // For the second argument, check if there exists an initializer.
       if (initializedTensors.ContainKey(legalize_name(item))) {
-          inputs.push_back(
-                initializedTensors.EmitInitializerForInputTensor(
-                    UnknownLoc(), builder_, legalize_name(item)));
+        inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
+            UnknownLoc(), builder_, legalize_name(item)));
       } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
       }

@@ -372,7 +369,6 @@ private:
 #if INCLUDE_ONNX_ML == 1
 #include "src/Builder/MLOpBuildTable.inc"
 #endif
-
   }
 
   /*!
@@ -388,8 +384,8 @@ private:
    *   output tensor.
    */
   void ImportOutputTensor(const onnx::ValueInfoProto &output,
-                          llvm::SmallVectorImpl<mlir::Type> &ret_types,
-                          llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
+      llvm::SmallVectorImpl<mlir::Type> &ret_types,
+      llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
     auto output_tensor_legalized_name = legalize_name(output.name());
     assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
            "Output tensor not found");

@@ -400,8 +396,8 @@ private:
     ret_vals.push_back(tensor_val);
   }
 
-  void ImportGraph(const onnx::GraphProto &graph,
-                   const std::string &name = "main_graph") {
+  void ImportGraph(
+      const onnx::GraphProto &graph, const std::string &name = "main_graph") {
     // Maintain a mapping between the parameter and its initializer.
     for (auto initializer : graph.initializer()) {
       auto name = initializer.name();

@@ -426,8 +422,7 @@ private:
 
     // Emit the entry point operation which specifies the number of user
     // inputs and outputs.
-    auto entryPoint = mlir::ONNXEntryPointOp::create(
-        UnknownLoc(), mainFunc,
+    auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
         /*numInputs=*/graph.input().size() - graph.initializer().size(),
         /*numOutputs=*/graph.output().size());
 

@@ -454,8 +449,8 @@ private:
 
     // Create a NoneTyped constant to be used for optional operation inputs
     // which are not used.
-    none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
-        builder_.getUnitAttr());
+    none_ =
+        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
 
     // Import nodes in the graph.
     for (const auto &item : graph.node()) {

@@ -483,8 +478,7 @@ private:
 namespace onnx_mlir {
 
 void ImportFrontendModelFile(std::string model_fname,
-                             mlir::MLIRContext &context,
-                             mlir::OwningModuleRef &module) {
+    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
   onnx::ModelProto model;
   std::fstream input(model_fname, std::ios::in | std::ios::binary);
 
@@ -36,11 +36,10 @@ namespace onnx_mlir {
  *  @return MLIR::module generated for the ONNX model.
  */
 void ImportFrontendModelFile(std::string model_fname,
-                             mlir::MLIRContext &context,
-                             mlir::OwningModuleRef &module);
+    mlir::MLIRContext &context, mlir::OwningModuleRef &module);
 
 /*!
  *  TODO: Import models into other extension dialects that cover the
  *  operations specific to other frameworks such as Tensorflow or Pytorch.
  */
-}  // namespace onnx_mlir
+} // namespace onnx_mlir
@@ -1,4 +1,5 @@
-//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering --------===//
+//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering
+//--------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -21,10 +22,9 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
 public:
   using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(ONNXEntryPointOp op,
-                                     PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
-        op,
+  LogicalResult matchAndRewrite(
+      ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(op,
         op.getAttrOfType<SymbolRefAttr>(
             ONNXEntryPointOp::getEntryPointFuncAttrName()),
         op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),

@@ -55,8 +55,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
 
   // We define the specific operations, or dialects, that are legal targets for
   // this lowering.
-  target
-      .addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
+  target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
 
   // TODO: enable this once more ops are supported.
   // We also define the ONNX dialect as Illegal so that the conversion will fail
@@ -81,8 +80,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   // Type conversion for function signatures.
   // Call MLIR FuncOp signature conversion when result type is
   // a ranked tensor.
-  populateFuncOpTypeConversionPattern(patterns, &getContext(),
-                                      tensor_to_memref_converter);
+  populateFuncOpTypeConversionPattern(
+      patterns, &getContext(), tensor_to_memref_converter);
 
   // Frontend operation lowering.
   // Math

@@ -119,5 +118,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
   return std::make_unique<FrontendToKrnlLoweringPass>();
 }
 
-static PassRegistration<FrontendToKrnlLoweringPass>
-    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
+static PassRegistration<FrontendToKrnlLoweringPass> pass(
+    "lower-frontend", "Lower frontend ops to Krnl dialect.");
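For orientation, the pieces shown in these hunks fit together roughly as follows. This is a condensed sketch, not the literal pass body: the names follow the hunks above, and the driver call assumes the MLIR dialect-conversion API of that period:

void FrontendToKrnlLoweringPass::runOnOperation() {
  // Declare which dialects are legal after lowering.
  ConversionTarget target(getContext());
  target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();

  // Collect the rewrite patterns contributed by each populate* helper.
  OwningRewritePatternList patterns;
  populateLoweringONNXGemmOpPattern(patterns, &getContext());

  // Run the conversion; any remaining illegal op fails the pass.
  if (failed(applyPartialConversion(getOperation(), target, patterns)))
    signalPassFailure();
}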
@@ -499,9 +499,8 @@ template <typename ElementwiseUnaryOp>
 struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
   ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
       : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
     // TODO: Check that the types are valid.
     // An element-wise unary operation must have all operands and the result of
     // the same type. This should have been verified by the verifier.

@@ -566,9 +565,8 @@ template <typename ElementwiseVariadicOp>
 struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
   ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
       : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
     // TODO: Check that the types are valid.
     // An element-wise variadic operation must have all operands and the result
     // of the same type. This should have been verified by the verifier.
@@ -1,4 +1,5 @@
-//===----------------- Gemm.cpp - Lowering Gemm Op -------------------------===//
+//===----------------- Gemm.cpp - Lowering Gemm Op
+//-------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -17,9 +18,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   ONNXGemmOpLowering(MLIRContext *ctx)
       : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
     bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
 

@@ -32,12 +32,10 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 
     auto memRefType = convertToMemRefType(*op->result_type_begin());
 
-    auto alphaAttr =
-        FloatAttr::get(memRefType.getElementType(),
-                       llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
-    auto betaAttr =
-        FloatAttr::get(memRefType.getElementType(),
-                       llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
+    auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
+        llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
+    auto betaAttr = FloatAttr::get(memRefType.getElementType(),
+        llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
     auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
 
@@ -101,8 +99,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     optimizedReductionLoops.reserve(1);
     reductionLoops.push_back(originalLoops[2]);
     optimizedReductionLoops.push_back(optimizedLoops[2]);
-    KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
-                                         optimizedReductionLoops);
+    KrnlIterateOperandPack reductionPack(
+        rewriter, reductionLoops, optimizedReductionLoops);
     // Induction variable for the reduction dimension
     // Try to find and use a static value from A or B first.
     // If it failed then use a dynamic value.

@@ -167,8 +165,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
     auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
     if (hasBias) {
-      auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
-                                                broadcastedDimInfo);
+      auto loopCIVs = getLoopIVsForBroadcasting(
+          loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
       auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
       auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
       auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);

@@ -214,7 +212,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   }
 };
 
-void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
-                                       MLIRContext *ctx) {
+void populateLoweringONNXGemmOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
   patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
 }
@@ -16,9 +16,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
   ONNXMatMulOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
 
     ONNXMatMulOpOperandAdaptor operandAdaptor(operands);

@@ -119,8 +118,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       // Define loops for batch dimensions.
       std::vector<Value> originalLoops;
       std::vector<Value> optimizedLoops;
-      Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
-            optimizedLoops, memRefShape.size());
+      Block *optimizationBlock = defineLoops(
+          rewriter, loc, originalLoops, optimizedLoops, memRefShape.size());
 
       // Outer KrnlIterateOp
       SmallVector<Value, 4> loopBatchIVs;
@@ -139,8 +138,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
           outerLoops.push_back(originalLoops[i]);
           optimizedOuterLoops.push_back(optimizedLoops[i]);
         }
-        KrnlIterateOperandPack outerPack(rewriter, outerLoops,
-                                         optimizedOuterLoops);
+        KrnlIterateOperandPack outerPack(
+            rewriter, outerLoops, optimizedOuterLoops);
         for (int i = 0; i < batchAxes.size(); ++i) {
           addDimensionToPack(rewriter, loc, outerPack, alloc, i);
         }

@@ -176,11 +175,11 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
           optimizedMatmulLoops.emplace_back(
               optimizedLoops[memRefShape.size() - i]);
         }
-        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
-                                          optimizedMatmulLoops);
+        KrnlIterateOperandPack matmulPack(
+            rewriter, matmulLoops, optimizedMatmulLoops);
         for (int i = 2; i > 0; --i) {
-          addDimensionToPack(rewriter, loc, matmulPack, alloc,
-                             memRefShape.size() - i);
+          addDimensionToPack(
+              rewriter, loc, matmulPack, alloc, memRefShape.size() - i);
         }
         matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
       } else {

@@ -190,10 +189,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
         matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
         optimizedMatmulLoops.emplace_back(
             optimizedLoops[memRefShape.size() - 1]);
-        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
-                                          optimizedMatmulLoops);
-        addDimensionToPack(rewriter, loc, matmulPack, alloc,
-                           memRefShape.size() - 1);
+        KrnlIterateOperandPack matmulPack(
+            rewriter, matmulLoops, optimizedMatmulLoops);
+        addDimensionToPack(
+            rewriter, loc, matmulPack, alloc, memRefShape.size() - 1);
         matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
       }
 
@@ -230,8 +229,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       std::vector<Value> optimizedReduceLoops;
       Block *optimizationReduceBlock =
           defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
-      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
-                                        optimizedReduceLoops);
+      KrnlIterateOperandPack reducePack(
+          rewriter, reduceLoops, optimizedReduceLoops);
       addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
       auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
 

@@ -292,8 +291,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       std::vector<Value> optimizedReduceLoops;
       Block *optimizationReduceBlock =
           defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
-      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
-                                        optimizedReduceLoops);
+      KrnlIterateOperandPack reducePack(
+          rewriter, reduceLoops, optimizedReduceLoops);
       addDimensionToPack(rewriter, loc, reducePack, A, 0);
       auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
 
@@ -102,9 +102,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
   ONNXReductionOpLowering(MLIRContext *ctx)
       : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
     /*
      * Condition: reduction function must be associative and commutative.
     *
@@ -1,4 +1,5 @@
-//===--------------- Conv.cpp - Lowering Convolution Op --------------------===//
+//===--------------- Conv.cpp - Lowering Convolution Op
+//--------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -175,14 +176,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
 
         // Emit the bias, if needed.
         if (hasBias) {
-          auto loadResult =
-              rewriter.create<LoadOp>(loc, alloc, resultIndices);
+          auto loadResult = rewriter.create<LoadOp>(loc, alloc, resultIndices);
           SmallVector<Value, 4> biasIndices;
           biasIndices.emplace_back(kernel);
-          auto loadBias =
-              rewriter.create<LoadOp>(loc, biasOperand, kernel);
-          auto resultWithBias = rewriter.create<MulFOp>(
-            loc, loadResult, loadBias);
+          auto loadBias = rewriter.create<LoadOp>(loc, biasOperand, kernel);
+          auto resultWithBias =
+              rewriter.create<MulFOp>(loc, loadResult, loadBias);
           // Store initializer value into output location.
           rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
         }
@@ -459,7 +459,8 @@ struct ONNXPoolOpLowering : public ConversionPattern {
         poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
 
         // poolStartMap and poolEndMap
-        poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
+        poolStartMap =
+            AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
         poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
       }
 
@@ -36,9 +36,7 @@ MemRefType convertToMemRefType(Type type) {
 
 /// Insert an allocation and deallocation for the given MemRefType.
 Value insertAllocAndDealloc(MemRefType type, Location loc,
-                                   PatternRewriter &rewriter,
-                                   bool insertDealloc,
-                                   ArrayRef<Value> operands) {
+    PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands) {
   // Put together alloc operands for any dynamic dimensions of the memref.
   AllocOp alloc;
   if (!operands.empty()) {
@@ -64,10 +62,10 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
           auto operandDim =
               rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
           if (maxDim) {
-            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
-                                                        operandDim, maxDim);
-            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
-                                               maxDim);
+            auto maxCondition = rewriter.create<CmpIOp>(
+                loc, CmpIPredicate::sgt, operandDim, maxDim);
+            maxDim = rewriter.create<SelectOp>(
+                loc, maxCondition, operandDim, maxDim);
           } else {
             maxDim = operandDim;
           }
@@ -122,8 +120,8 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
 // Create a mapping from result type's dimensions to input type's dimensions,
 // given that the result type is the result of a reduction op over the input
 // type.
-std::map<int64_t, int64_t>
-getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
+std::map<int64_t, int64_t> getReductionMapping(
+    MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
   std::map<int64_t, int64_t> OutInDimMap;
   int64_t rank = inputTy.getRank();
 
@@ -152,9 +150,8 @@ getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
 
 // Add bounds associated with the op operand to the KRNL iteration pack.
 // Dynamic dimenions are supported.
-void addDimensionToPack(ConversionPatternRewriter &rewriter,
-                               Location loc, KrnlIterateOperandPack &pack,
-                               Value operand, int index) {
+void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
+    KrnlIterateOperandPack &pack, Value operand, int index) {
   auto shape = operand.getType().cast<MemRefType>().getShape();
   if (shape[index] < 0) {
     pack.pushConstantBound(0);
@@ -168,10 +165,9 @@ void addDimensionToPack(ConversionPatternRewriter &rewriter,
 
 // Function that defines the KRNL dialect loops and their respective
 // optimized version.
-KrnlOptimizeLoopsOp
-emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
-                   std::vector<Value> &loops,
-                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
+KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
+    Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
+    int64_t numLoops) {
   // Define loops.
   auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
   loops.reserve(numLoops);
@@ -190,9 +186,8 @@ emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
 // Function that emits the loops and their optimized version.
 // The function returns a reference to the inner optimization block.
 Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
-                          std::vector<Value> &loops,
-                          std::vector<Value> &optimizedLoops,
-                          int64_t numLoops) {
+    std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
+    int64_t numLoops) {
   KrnlOptimizeLoopsOp optimizedLoopsOp =
       emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
   return &optimizedLoopsOp.region().front();
@@ -201,10 +196,9 @@ Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
 // Function which emits a basic set of loops and optimized loops
 // for a given operation argument. A reference to the loop optimization
 // block is returned in the last argument of the function.
-void emitKrnlLoopsAndIterationForOperand(
-    ConversionPatternRewriter &rewriter, Location loc, Value operand,
-    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
-    KrnlIterateOp &iterateOp) {
+void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
+    Location loc, Value operand, std::vector<Value> &originalLoops,
+    KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
   // Operand shape.
   auto shape = operand.getType().cast<MemRefType>().getShape();
 
@@ -240,9 +234,9 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
 
 // Get run-time dimension information for unknown dimensions used for
 // broadcasting.
-std::map<int, std::map<int, Value>>
-getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
-                      MemRefType memRefType, ArrayRef<Value> operands) {
+std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
+    ConversionPatternRewriter &rewriter, MemRefType memRefType,
+    ArrayRef<Value> operands) {
   auto memRefShape = memRefType.getShape();
   int64_t rank = memRefShape.size();
   // For unknown dimensions, we need to get dimension values at runtime in
@@ -286,10 +280,9 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
 
 // Extract induction variables that are used for broadcasting values of a
 // given operand.
-std::vector<Value>
-getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
-                          ArrayRef<Value> loopIVs, Value operand,
-                          std::map<int, Value> broadcastedDims) {
+std::vector<Value> getLoopIVsForBroadcasting(Location loc,
+    ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
+    std::map<int, Value> broadcastedDims) {
   // `operand` must has a ranked type. This should have been checked by the
   // shape inference pass.
   auto operandShape = operand.getType().cast<MemRefType>().getShape();
@@ -310,8 +303,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
       // If its value is 1, it is broadcasted dimension.
       // Otherwise, non-broadcasted dimension.
       auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
-                                           loopIVs[loopIdx]);
+      auto idx = rewriter.create<SelectOp>(
+          loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]);
       newLoopIVs.insert(newLoopIVs.begin(), idx);
     } else {
       // Non-broadcasted dimension
@@ -30,7 +30,7 @@ struct ONNXConcatOpLowering : public ConversionPattern {
     auto memRefType = convertToMemRefType(*op->result_type_begin());
     auto resultShape = memRefType.getShape();
     auto rank = resultShape.size();
-    assert((axis >=0 && axis < rank) && "Concat axis out of bounds");
+    assert((axis >= 0 && axis < rank) && "Concat axis out of bounds");
 
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
@@ -17,8 +17,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {
 
   ONNXConstantOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {
-        constantID = 0;
-      }
+    constantID = 0;
+  }
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {

@@ -34,12 +34,11 @@ struct ONNXConstantOpLowering : public ConversionPattern {
     // Shape based computations.
     auto shape = memRefType.getShape();
     int64_t numElements = 1;
-    for (int i=0; i<shape.size(); ++i)
+    for (int i = 0; i < shape.size(); ++i)
       numElements *= shape[i];
 
     // Emit the constant global in Krnl dialect.
-    auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
-        memRefType,
+    auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
         rewriter.getI64ArrayAttr(shape),
         rewriter.getStringAttr("constant_" + std::to_string(constantID)),
         constantOp.value().getValue());
@@ -24,15 +24,15 @@ enum Kinds {
 }
 
 class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
- public:
+public:
   using Base::Base;
 
   // Support type inquiry through isa, cast and dyn_cast.
   static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; }
 
   // Get a unique instance of Loop type.
-  static LoopType get(mlir::MLIRContext* context) {
+  static LoopType get(mlir::MLIRContext *context) {
     return Base::get(context, KrnlTypes::Loop);
   }
 };
-}  // namespace mlir
+} // namespace mlir
@@ -39,7 +39,6 @@ MLONNXOpsDialect::MLONNXOpsDialect(mlir::MLIRContext *ctx)
       >();
 }

-
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

@@ -19,14 +19,14 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"

-#include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
+#include "src/Interface/ShapeInferenceInterface.hpp"

 namespace mlir {

 class MLONNXOpsDialect : public Dialect {
- public:
-  MLONNXOpsDialect(MLIRContext* context);
+public:
+  MLONNXOpsDialect(MLIRContext *context);

   /// Provide a utility accessor to the dialect namespace. This is used by
   /// several utilities for casting between dialects.
@@ -38,6 +38,6 @@ class MLONNXOpsDialect : public Dialect {
 #define GET_OP_CLASSES
 #include "src/Dialect/MLONNX/MLONNXOps.hpp.inc"

-}  // end namespace mlir
+} // end namespace mlir

 namespace onnx_mlir {}

@@ -19,14 +19,14 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"

-#include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
+#include "src/Interface/ShapeInferenceInterface.hpp"

 namespace mlir {

 class ONNXOpsDialect : public Dialect {
- public:
-  ONNXOpsDialect(MLIRContext* context);
+public:
+  ONNXOpsDialect(MLIRContext *context);

   /// Provide a utility accessor to the dialect namespace. This is used by
   /// several utilities for casting between dialects.
@@ -38,6 +38,6 @@ class ONNXOpsDialect : public Dialect {
 #define GET_OP_CLASSES
 #include "src/Dialect/ONNX/ONNXOps.hpp.inc"

-}  // end namespace mlir
+} // end namespace mlir

 namespace onnx_mlir {}

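Aside: the `#include` shuffles in the two dialect headers above are clang-format's alphabetical include sorting. A hedged dry-run check on one header, without modifying it (assumes `clang-format-9` on PATH; the count is the number of pending edits):

    # Emit would-be edits as XML and count them; 0 means the file is clean.
    clang-format-9 -output-replacements-xml src/Dialect/ONNX/ONNXOps.hpp \
        | grep -c "<replacement "
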
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//

-
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"

 using namespace mlir;
@@ -20,4 +19,3 @@ using namespace mlir;
 //===----------------------------------------------------------------------===//

 #include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc"
-

@@ -22,4 +22,4 @@ namespace mlir {
 /// Include the auto-generated declarations.
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp.inc"

-}  // end namespace mlir
+} // end namespace mlir
@@ -16,4 +16,4 @@ namespace mlir {
 /// Include the auto-generated declarations.
 #include "src/Interface/ShapeInference.cpp.inc"

-}  // end namespace mlir
+} // end namespace mlir

@@ -18,4 +18,4 @@ namespace mlir {
 /// Include the auto-generated declarations.
 #include "src/Interface/ShapeInference.hpp.inc"

-}  // end namespace mlir
+} // end namespace mlir

@@ -14,7 +14,7 @@

 #ifdef _WIN32
 #include <io.h>
-#else 
+#else
 #include <unistd.h>
 #endif

@@ -22,7 +22,7 @@ using namespace std;
 using namespace onnx_mlir;

 void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
-              mlir::OwningModuleRef &module) {
+    mlir::OwningModuleRef &module) {
   // Handle '.mlir' input to the ONNX MLIR frontend.
   // The mlir format indicates that one or more of the supported
   // representations are used in the file.
@@ -46,10 +46,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
 void EmitLLVMBitCode(
     const mlir::OwningModuleRef &module, string outputFilename) {
   error_code error;
-  llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error,
-                                           llvm::sys::fs::F_None);
-  llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
-                           moduleBitcodeStream);
+  llvm::raw_fd_ostream moduleBitcodeStream(
+      outputFilename, error, llvm::sys::fs::F_None);
+  llvm::WriteBitcodeToFile(
+      *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
   moduleBitcodeStream.flush();
 }

@@ -90,7 +90,7 @@ void addKrnlToLLVMPasses(mlir::PassManager &pm) {
 }

 void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
-	mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
+    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
   // Decide if the input file is an ONNX model or a model specified
   // in MLIR. The extension of the file is the decider.
   string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
@@ -99,7 +99,6 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
   assert(inputIsONNX != inputIsMLIR &&
          "Either ONNX model or MLIR file needs to be provided.");

-   
   if (inputIsONNX) {
     ImportFrontendModelFile(inputFilename, context, module);
   } else {
@@ -119,8 +118,8 @@ void outputCode(
   module->dump();
   fflush(stderr);
   // set modified stderr as original stderr
-  _dup2(stderrOrigin, _fileno( stderr ));
-#else 
+  _dup2(stderrOrigin, _fileno(stderr));
+#else
   if (fork() == 0) {
     freopen(tempFilename.c_str(), "w", stderr);
     module->dump();
@@ -151,7 +150,7 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
   // necessary when emitting the .bc file.
   if (emissionTarget == EmitLLVMBC) {
     // Write LLVM bitcode to disk.
-    string outputFilename =  outputBaseName + ".bc";
+    string outputFilename = outputBaseName + ".bc";
     EmitLLVMBitCode(module, outputFilename);
     printf("LLVM bitcode written to %s\n", outputFilename.c_str());
   } else {

@@ -28,9 +28,9 @@
 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
 #include "mlir/ExecutionEngine/OptUtils.h"
-#include "mlir/InitAllDialects.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
+#include "mlir/InitAllDialects.h"
 #include "mlir/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -46,10 +46,10 @@ enum EmissionTargetType {
 };

 void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
-              mlir::OwningModuleRef &module);
+    mlir::OwningModuleRef &module);

 void EmitLLVMBitCode(
-	const mlir::OwningModuleRef &module, std::string outputFilename);
+    const mlir::OwningModuleRef &module, std::string outputFilename);

 void registerDialects();

@@ -66,8 +66,7 @@ void processInputFile(std::string inputFilename,
     mlir::OwningModuleRef &module);

 void outputCode(
-    mlir::OwningModuleRef &module, std::string filename,
-    std::string extension);
+    mlir::OwningModuleRef &module, std::string filename, std::string extension);

 void emitOutputFiles(std::string outputBaseName,
     EmissionTargetType emissionTarget, mlir::MLIRContext &context,

@@ -38,4 +38,4 @@ std::unique_ptr<Pass> createElideConstGlobalValuePass();
 /// Pass for lowering Krnl dialect to LLVM dialect.
 std::unique_ptr<Pass> createKrnlLowerToLLVMPass();

-}  // end namespace mlir
+} // end namespace mlir

@@ -1,15 +1,15 @@
 enum DYN_MEMREF_DATA_TYPE {
   UNDEFINED = 0;
   // Basic types.
-  FLOAT = 1;   // float
-  UINT8 = 2;   // uint8_t
-  INT8 = 3;    // int8_t
-  UINT16 = 4;  // uint16_t
-  INT16 = 5;   // int16_t
-  INT32 = 6;   // int32_t
-  INT64 = 7;   // int64_t
-  STRING = 8;  // string
-  BOOL = 9;    // bool
+  FLOAT = 1;  // float
+  UINT8 = 2;  // uint8_t
+  INT8 = 3;   // int8_t
+  UINT16 = 4; // uint16_t
+  INT16 = 5;  // int16_t
+  INT32 = 6;  // int32_t
+  INT64 = 7;  // int64_t
+  STRING = 8; // string
+  BOOL = 9;   // bool

   // IEEE754 half-precision floating-point format (16 bits wide).
   // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
@@ -18,8 +18,8 @@ enum DYN_MEMREF_DATA_TYPE {
   DOUBLE = 11;
   UINT32 = 12;
   UINT64 = 13;
-  COMPLEX64 = 14;     // complex with float32 real and imaginary components
-  COMPLEX128 = 15;    // complex with float64 real and imaginary components
+  COMPLEX64 = 14;  // complex with float32 real and imaginary components
+  COMPLEX128 = 15; // complex with float64 real and imaginary components

   // Non-IEEE floating-point format based on IEEE754 single-precision
   // floating-point number truncated to 16 bits.

@@ -33,8 +33,8 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
   return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
 }

-void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
-                  DynMemRef *tensor) {
+void setDynMemRef(
+    OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
   if (tensorDict->orderedNames.size() <= idx)
     tensorDict->orderedNames.resize(idx + 1);

@@ -38,7 +38,6 @@ typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
 #ifdef __cplusplus
 extern "C" {
 #endif
-

 // Get number of dynamic memrefs in OrderedDynMemRefDict dict.
 int numDynMemRefs(OrderedDynMemRefDict *dict);
@@ -53,8 +52,8 @@ DynMemRef *createDynMemRef(int rank);
 DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);

 // Set the i-th dynmemref in orderedDict to be dynMemRef.
-void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
-                  DynMemRef *dynMemRef);
+void setDynMemRef(
+    OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);

 // Get data pointer from dynMemRef.
 void *getData(DynMemRef *dynMemRef);

@@ -1,14 +1,14 @@
 #include "Runtime.hpp"

-ExecutionSession::ExecutionSession(std::string sharedLibPath,
-                                   std::string entryPointName) {
+ExecutionSession::ExecutionSession(
+    std::string sharedLibPath, std::string entryPointName) {
   _sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
   _entryPointFunc =
       (entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
 }

-std::vector<py::array>
-ExecutionSession::run(std::vector<py::array> inputsPyArray) {
+std::vector<py::array> ExecutionSession::run(
+    std::vector<py::array> inputsPyArray) {
   assert(_entryPointFunc && "entry point not loaded");
   auto *wrappedInput = createOrderedDynMemRefDict();
   int inputIdx = 0;
@@ -40,8 +40,8 @@ ExecutionSession::run(std::vector<py::array> inputsPyArray) {
   auto *wrappedOutput = _entryPointFunc(wrappedInput);
   for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
     auto *dynMemRef = getDynMemRef(wrappedOutput, i);
-    auto shape = std::vector<int64_t>(dynMemRef->sizes,
-                                      dynMemRef->sizes + dynMemRef->rank);
+    auto shape = std::vector<int64_t>(
+        dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
     outputPyArrays.emplace_back(
         py::array(py::dtype("float32"), shape, dynMemRef->data));
   }

@@ -144,9 +144,9 @@ public:

       assert(krnlGlobalOp.value().hasValue() &&
              "Krnl Global must always have a value");
-      global = rewriter.create<LLVM::GlobalOp>(loc,
-          llvmGlobalType, /*isConstant=*/true,
-          LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
+      global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
+          /*isConstant=*/true, LLVM::Linkage::Internal, name,
+          krnlGlobalOp.value().getValue());
     }

     // Some frequently used types.

@@ -75,7 +75,7 @@ public:
     OwningRewritePatternList patterns;
     auto *context = &getContext();
     ConstantOp::getCanonicalizationPatterns(patterns, context);
-      applyPatternsAndFoldGreedily(f, patterns);
+    applyPatternsAndFoldGreedily(f, patterns);
   }
 };
 } // end anonymous namespace

@@ -12,35 +12,35 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"

-#include <numeric>
 #include "src/Dialect/ONNX/ONNXOps.hpp"
+#include <numeric>

 using namespace mlir;

 namespace {
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Transform/ONNX/ONNXCombine.inc"
-}  // end anonymous namespace
+} // end anonymous namespace

 /// Register optimization patterns as "canonicalization" patterns
 /// on the ONNXMatMultOp.
 void ONNXAddOp::getCanonicalizationPatterns(
-    OwningRewritePatternList& results, MLIRContext* context) {
+    OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<MulAddToGemmOptPattern>(context);
 }

 void ONNXGemmOp::getCanonicalizationPatterns(
-        OwningRewritePatternList& results, MLIRContext* context) {
-    results.insert<FuseGemmFollowedByAddition>(context);
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<FuseGemmFollowedByAddition>(context);
 }
 /// on the ONNXIdentityOp.
 void ONNXIdentityOp::getCanonicalizationPatterns(
-    OwningRewritePatternList& results, MLIRContext* context) {
+    OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<IdentityEliminationPattern>(context);
 }

-///on the ONNXPadConstantValueOp.
+/// on the ONNXPadConstantValueOp.
 void ONNXPadConstantValueOp::getCanonicalizationPatterns(
-    OwningRewritePatternList& result, MLIRContext* context) {
+    OwningRewritePatternList &result, MLIRContext *context) {
   result.insert<ConstantPadPattern>(context);
 }

@@ -27,7 +27,8 @@ namespace {
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Transform/ONNX/ONNXDecompose.inc"

-struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
+struct DecomposeONNXToONNXPass
+    : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
   void runOnFunction() final;
 };
 } // end anonymous namespace.

@@ -9,10 +9,10 @@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/StandardTypes.h"

 #include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Pass/Passes.hpp"
@@ -25,7 +25,8 @@ namespace {
  *  candidate operations and propagating the shape information until the list
 *  of operations is empty [credit MLIR authors].
 */
-class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
+class ShapeInferencePass
+    : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
 public:
   void runOnFunction() override {
     auto f = getFunction();
@@ -63,8 +64,7 @@ public:

     if (auto terminator_op = f.getBody().back().getTerminator()) {
       auto results = terminator_op->getOperandTypes();
-      f.setType(FunctionType::get(
-          f.getType().getInputs(),
+      f.setType(FunctionType::get(f.getType().getInputs(),
           std::vector<Type>(results.begin(), results.end()), f.getContext()));
     }
   }
@@ -146,5 +146,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
   return std::make_unique<ShapeInferencePass>();
 }

-static PassRegistration<ShapeInferencePass>
-    pass("shape-inference", "Shape inference for frontend dialects.");
+static PassRegistration<ShapeInferencePass> pass(
+    "shape-inference", "Shape inference for frontend dialects.");

src/main.cpp (24 changes)
| int main(int argc, char *argv[]) { | int main(int argc, char *argv[]) { | ||||||
|   registerDialects(); |   registerDialects(); | ||||||
| 
 | 
 | ||||||
|   llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options", |   llvm::cl::OptionCategory OnnxMlirOptions( | ||||||
|                                        "These are frontend options."); |       "ONNX MLIR Options", "These are frontend options."); | ||||||
|   llvm::cl::opt<string> inputFilename( |   llvm::cl::opt<string> inputFilename(llvm::cl::Positional, | ||||||
|       llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"), |       llvm::cl::desc("<input file>"), llvm::cl::init("-"), | ||||||
|       llvm::cl::cat(OnnxMlirOptions)); |       llvm::cl::cat(OnnxMlirOptions)); | ||||||
| 
 | 
 | ||||||
|   llvm::cl::opt<EmissionTargetType> emissionTarget( |   llvm::cl::opt<EmissionTargetType> emissionTarget( | ||||||
|       llvm::cl::desc("Choose target to emit:"), |       llvm::cl::desc("Choose target to emit:"), | ||||||
|       llvm::cl::values( |       llvm::cl::values( | ||||||
|           clEnumVal(EmitONNXBasic, |           clEnumVal(EmitONNXBasic, | ||||||
|                     "Ingest ONNX and emit the basic ONNX operations without" |               "Ingest ONNX and emit the basic ONNX operations without" | ||||||
|                     "inferred shapes."), |               "inferred shapes."), | ||||||
|           clEnumVal(EmitONNXIR, |           clEnumVal( | ||||||
|                     "Ingest ONNX and emit corresponding ONNX dialect."), |               EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."), | ||||||
|           clEnumVal(EmitMLIR, |           clEnumVal( | ||||||
|                     "Lower model to MLIR built-in transformation dialect."), |               EmitMLIR, "Lower model to MLIR built-in transformation dialect."), | ||||||
|           clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."), |           clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."), | ||||||
|           clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) " |           clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) " | ||||||
|                                 "LLVM bitcode for model.")), |                                 "LLVM bitcode for model.")), | ||||||
|       llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions)); |       llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions)); | ||||||
| 
 | 
 | ||||||
|   llvm::cl::HideUnrelatedOptions(OnnxMlirOptions); |   llvm::cl::HideUnrelatedOptions(OnnxMlirOptions); | ||||||
|   llvm::cl::ParseCommandLineOptions(argc, argv, |   llvm::cl::ParseCommandLineOptions( | ||||||
|                                     "ONNX MLIR modular optimizer driver\n"); |       argc, argv, "ONNX MLIR modular optimizer driver\n"); | ||||||
| 
 | 
 | ||||||
|   mlir::MLIRContext context; |   mlir::MLIRContext context; | ||||||
|   mlir::OwningModuleRef module; |   mlir::OwningModuleRef module; | ||||||
|  |  | ||||||
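Aside: since the whole patch is mechanical, it should be reproducible locally. A minimal sketch, assuming clang-format version 9 is installed as `clang-format-9` and that all C++ sources live under src/:

    # Reformat every C++ source and header in place, then inspect the result.
    find src \( -name '*.cpp' -o -name '*.hpp' \) -print0 \
        | xargs -0 clang-format-9 -i
    git diff --stat
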