[NFC] Set up clang-format Github Action (#119)

* Run clang-format on all source code.

* Add Clang-Format Github Action.

* Apply patch produced by Clang-Format Bot.

* nit.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tian Jin 2020-05-13 22:37:51 +08:00 committed by GitHub
parent 24343177b8
commit 7f2bffb27d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 272 additions and 260 deletions

40
.github/workflows/main.yml vendored Normal file
View File

@ -0,0 +1,40 @@
# This is a basic workflow to help you get started with Actions
name: Clang-Format Bot
# Controls when the action will run. Triggers the workflow on push or pull request
# events but only for the master branch
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "build"
build:
# The type of runner that the job will run on
runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
- name: clang-format lint
uses: DoozyX/clang-format-lint-action@v0.5
with:
# Source folder to check formatting
source: ./src
# Version of clang-format
clangFormatVersion: 9 # optional, default is 9
# Runs a single command using the runners shell
- name: Run a one-line script
run: echo Hello, world!
# Runs a set of commands using the runners shell
- name: Run a multi-line script
run: |
echo Add other actions to build,
echo test, and deploy your project.

View File

@ -12,8 +12,8 @@
namespace onnx_mlir { namespace onnx_mlir {
void replaceAll(std::string &str, const std::string &from, void replaceAll(
const std::string &to) { std::string &str, const std::string &from, const std::string &to) {
if (from.empty()) if (from.empty())
return; return;
size_t start_pos = 0; size_t start_pos = 0;
@ -121,7 +121,6 @@ void InitializedTensorMapping::AddMapping(
nameToInitializedTensor.emplace(name, tensor); nameToInitializedTensor.emplace(name, tensor);
} }
bool InitializedTensorMapping::ContainKey(std::string name) { bool InitializedTensorMapping::ContainKey(std::string name) {
return nameToInitializedTensor.count(name) != 0; return nameToInitializedTensor.count(name) != 0;
} }
@ -132,8 +131,8 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
onnx::TensorProto initializer = GetInitializedTensor(name); onnx::TensorProto initializer = GetInitializedTensor(name);
// Tensor dimensions. // Tensor dimensions.
llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(), llvm::ArrayRef<int64_t> tensorDims(
initializer.dims().size()); initializer.dims().data(), initializer.dims().size());
// Emit ConstantOp and record the mapping between the input and // Emit ConstantOp and record the mapping between the input and
// the constant value. // the constant value.
@ -142,33 +141,32 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
mlir::Type elementType; mlir::Type elementType;
mlir::ShapedType tensorType; mlir::ShapedType tensorType;
switch (initializer.data_type()) { switch (initializer.data_type()) {
case (onnx::TensorProto::FLOAT): { case (onnx::TensorProto::FLOAT): {
const auto& arrayAttrInitializer = const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
CreateArrayAttribute<float>(initializer); elementType = builder.getF32Type();
elementType = builder.getF32Type(); tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); constantDenseAttribute = mlir::DenseElementsAttr::get(
constantDenseAttribute = mlir::DenseElementsAttr::get( tensorType, llvm::makeArrayRef(arrayAttrInitializer));
tensorType, llvm::makeArrayRef(arrayAttrInitializer)); break;
break; }
} case (onnx::TensorProto::INT32): {
case (onnx::TensorProto::INT32): { const auto &arrayAttrInitializer =
const auto& arrayAttrInitializer = CreateArrayAttribute<int32_t>(initializer);
CreateArrayAttribute<int32_t>(initializer); elementType = builder.getIntegerType(32);
elementType = builder.getIntegerType(32); tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); constantDenseAttribute = mlir::DenseElementsAttr::get(
constantDenseAttribute = mlir::DenseElementsAttr::get( tensorType, llvm::makeArrayRef(arrayAttrInitializer));
tensorType, llvm::makeArrayRef(arrayAttrInitializer)); break;
break; }
} case (onnx::TensorProto::INT64): {
case (onnx::TensorProto::INT64): { const auto &arrayAttrInitializer =
const auto& arrayAttrInitializer = CreateArrayAttribute<int64_t>(initializer);
CreateArrayAttribute<int64_t>(initializer); elementType = builder.getIntegerType(64);
elementType = builder.getIntegerType(64); tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); constantDenseAttribute = mlir::DenseElementsAttr::get(
constantDenseAttribute = mlir::DenseElementsAttr::get( tensorType, llvm::makeArrayRef(arrayAttrInitializer));
tensorType, llvm::makeArrayRef(arrayAttrInitializer)); break;
break; }
}
} }
// Create ConstantOp for dense array. // Create ConstantOp for dense array.

View File

@ -20,8 +20,8 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
@ -37,11 +37,12 @@
#endif #endif
#include "onnx/onnx_pb.h" #include "onnx/onnx_pb.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
void replaceAll(std::string &str, const std::string &from, void replaceAll(
const std::string &to); std::string &str, const std::string &from, const std::string &to);
std::string legalize_name(std::string name); std::string legalize_name(std::string name);
@ -86,14 +87,14 @@ struct InitializedTensorMapping {
// This will allow the propagation of shape information passed in as an // This will allow the propagation of shape information passed in as an
// argument to operations such as Reshape and will enable other // argument to operations such as Reshape and will enable other
// optimizations such as constant folding. // optimizations such as constant folding.
mlir::Value EmitInitializerForInputTensor(mlir::Location loc, mlir::Value EmitInitializerForInputTensor(
mlir::OpBuilder &builder, std::string name); mlir::Location loc, mlir::OpBuilder &builder, std::string name);
// Get initialized tensor. // Get initialized tensor.
onnx::TensorProto& GetInitializedTensor(std::string name) { onnx::TensorProto &GetInitializedTensor(std::string name) {
assert(nameToInitializedTensor.find(name) != assert(
nameToInitializedTensor.end() && nameToInitializedTensor.find(name) != nameToInitializedTensor.end() &&
"Tensor initializer not found"); "Tensor initializer not found");
return nameToInitializedTensor.at(name); return nameToInitializedTensor.at(name);
} }

View File

@ -127,8 +127,8 @@ private:
* @param input onnx input tensor ValueInfoProto. * @param input onnx input tensor ValueInfoProto.
* @param symbol mlir input argument. * @param symbol mlir input argument.
*/ */
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, void ImportInputTensorSymbol(
mlir::Value symbol) { const onnx::ValueInfoProto &input, mlir::Value symbol) {
auto input_tensor_legalized_name = legalize_name(input.name()); auto input_tensor_legalized_name = legalize_name(input.name());
assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) && assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
"Found duplicate legalized input tensor names."); "Found duplicate legalized input tensor names.");
@ -136,8 +136,7 @@ private:
} }
typedef bstd::variant<int64_t, std::vector<int64_t>, float, typedef bstd::variant<int64_t, std::vector<int64_t>, float,
std::vector<float>, std::string, std::vector<float>, std::string, std::vector<std::string>>
std::vector<std::string>>
AttrValueType; AttrValueType;
struct ONNXAttrVisitor { struct ONNXAttrVisitor {
@ -213,8 +212,8 @@ private:
llvm_unreachable("Failed to convert attribute proto to name/value pair"); llvm_unreachable("Failed to convert attribute proto to name/value pair");
} }
std::vector<mlir::NamedAttribute> std::vector<mlir::NamedAttribute> ImportNodeAttributes(
ImportNodeAttributes(const onnx::NodeProto &node) { const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes; std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) { for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i); auto attr = node.attribute(i);
@ -281,26 +280,25 @@ private:
// TODO: Handle optional inputs. // TODO: Handle optional inputs.
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]), frontend_symbols_.AddMapping(
*(op.getODSResults(i).begin())); legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
} }
} }
template <typename T> template <typename T>
void buildOperation(const onnx::NodeProto &node, void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
int expectedNumOperands = -1, int expectedNumResults = -1) {
int expectedNumResults = -1) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) for (const auto &item : node.input())
if (initializedTensors.ContainKey(legalize_name(item))) { if (initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back(initializedTensors.EmitInitializerForInputTensor( inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
UnknownLoc(), builder_, legalize_name(item))); UnknownLoc(), builder_, legalize_name(item)));
} else if (frontend_symbols_.ContainKey(legalize_name(item))) { } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
} }
buildOutputAndOperation<T>(node, inputs, expectedNumOperands, buildOutputAndOperation<T>(
expectedNumResults); node, inputs, expectedNumOperands, expectedNumResults);
} }
void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) { void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
@ -310,9 +308,8 @@ private:
item = node.input()[i]; item = node.input()[i];
// For the second argument, check if there exists an initializer. // For the second argument, check if there exists an initializer.
if (initializedTensors.ContainKey(legalize_name(item))) { if (initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back( inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
initializedTensors.EmitInitializerForInputTensor( UnknownLoc(), builder_, legalize_name(item)));
UnknownLoc(), builder_, legalize_name(item)));
} else if (frontend_symbols_.ContainKey(legalize_name(item))) { } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
} }
@ -372,7 +369,6 @@ private:
#if INCLUDE_ONNX_ML == 1 #if INCLUDE_ONNX_ML == 1
#include "src/Builder/MLOpBuildTable.inc" #include "src/Builder/MLOpBuildTable.inc"
#endif #endif
} }
/*! /*!
@ -388,8 +384,8 @@ private:
* output tensor. * output tensor.
*/ */
void ImportOutputTensor(const onnx::ValueInfoProto &output, void ImportOutputTensor(const onnx::ValueInfoProto &output,
llvm::SmallVectorImpl<mlir::Type> &ret_types, llvm::SmallVectorImpl<mlir::Type> &ret_types,
llvm::SmallVectorImpl<mlir::Value> &ret_vals) { llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
auto output_tensor_legalized_name = legalize_name(output.name()); auto output_tensor_legalized_name = legalize_name(output.name());
assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) && assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
"Output tensor not found"); "Output tensor not found");
@ -400,8 +396,8 @@ private:
ret_vals.push_back(tensor_val); ret_vals.push_back(tensor_val);
} }
void ImportGraph(const onnx::GraphProto &graph, void ImportGraph(
const std::string &name = "main_graph") { const onnx::GraphProto &graph, const std::string &name = "main_graph") {
// Maintain a mapping between the parameter and its initializer. // Maintain a mapping between the parameter and its initializer.
for (auto initializer : graph.initializer()) { for (auto initializer : graph.initializer()) {
auto name = initializer.name(); auto name = initializer.name();
@ -426,8 +422,7 @@ private:
// Emit the entry point operation which specifies the number of user // Emit the entry point operation which specifies the number of user
// inputs and outputs. // inputs and outputs.
auto entryPoint = mlir::ONNXEntryPointOp::create( auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
UnknownLoc(), mainFunc,
/*numInputs=*/graph.input().size() - graph.initializer().size(), /*numInputs=*/graph.input().size() - graph.initializer().size(),
/*numOutputs=*/graph.output().size()); /*numOutputs=*/graph.output().size());
@ -454,8 +449,8 @@ private:
// Create a NoneTyped constant to be used for optional operation inputs // Create a NoneTyped constant to be used for optional operation inputs
// which are not used. // which are not used.
none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(), none_ =
builder_.getUnitAttr()); builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph. // Import nodes in the graph.
for (const auto &item : graph.node()) { for (const auto &item : graph.node()) {
@ -483,8 +478,7 @@ private:
namespace onnx_mlir { namespace onnx_mlir {
void ImportFrontendModelFile(std::string model_fname, void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext &context, mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
mlir::OwningModuleRef &module) {
onnx::ModelProto model; onnx::ModelProto model;
std::fstream input(model_fname, std::ios::in | std::ios::binary); std::fstream input(model_fname, std::ios::in | std::ios::binary);

View File

@ -36,11 +36,10 @@ namespace onnx_mlir {
* @return MLIR::module generated for the ONNX model. * @return MLIR::module generated for the ONNX model.
*/ */
void ImportFrontendModelFile(std::string model_fname, void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext &context, mlir::MLIRContext &context, mlir::OwningModuleRef &module);
mlir::OwningModuleRef &module);
/*! /*!
* TODO: Import models into other extension dialects that cover the * TODO: Import models into other extension dialects that cover the
* operations specific to other frameworks such as Tensorflow or Pytorch. * operations specific to other frameworks such as Tensorflow or Pytorch.
*/ */
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@ -1,4 +1,5 @@
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering --------===// //====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering
//--------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -21,10 +22,9 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public: public:
using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern; using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXEntryPointOp op, LogicalResult matchAndRewrite(
PatternRewriter &rewriter) const override { ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>( rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(op,
op,
op.getAttrOfType<SymbolRefAttr>( op.getAttrOfType<SymbolRefAttr>(
ONNXEntryPointOp::getEntryPointFuncAttrName()), ONNXEntryPointOp::getEntryPointFuncAttrName()),
op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()), op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
@ -55,8 +55,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// We define the specific operations, or dialects, that are legal targets for // We define the specific operations, or dialects, that are legal targets for
// this lowering. // this lowering.
target target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
// TODO: enable this once more ops are supported. // TODO: enable this once more ops are supported.
// We also define the ONNX dialect as Illegal so that the conversion will fail // We also define the ONNX dialect as Illegal so that the conversion will fail
@ -81,8 +80,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// Type conversion for function signatures. // Type conversion for function signatures.
// Call MLIR FuncOp signature conversion when result type is // Call MLIR FuncOp signature conversion when result type is
// a ranked tensor. // a ranked tensor.
populateFuncOpTypeConversionPattern(patterns, &getContext(), populateFuncOpTypeConversionPattern(
tensor_to_memref_converter); patterns, &getContext(), tensor_to_memref_converter);
// Frontend operation lowering. // Frontend operation lowering.
// Math // Math
@ -119,5 +118,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
return std::make_unique<FrontendToKrnlLoweringPass>(); return std::make_unique<FrontendToKrnlLoweringPass>();
} }
static PassRegistration<FrontendToKrnlLoweringPass> static PassRegistration<FrontendToKrnlLoweringPass> pass(
pass("lower-frontend", "Lower frontend ops to Krnl dialect."); "lower-frontend", "Lower frontend ops to Krnl dialect.");

View File

@ -499,9 +499,8 @@ template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of // An element-wise unary operation must have all operands and the result of
// the same type. This should have been verified by the verifier. // the same type. This should have been verified by the verifier.
@ -566,9 +565,8 @@ template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result // An element-wise variadic operation must have all operands and the result
// of the same type. This should have been verified by the verifier. // of the same type. This should have been verified by the verifier.

View File

@ -1,4 +1,5 @@
//===----------------- Gemm.cpp - Lowering Gemm Op -------------------------===// //===----------------- Gemm.cpp - Lowering Gemm Op
//-------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -17,9 +18,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
ONNXGemmOpLowering(MLIRContext *ctx) ONNXGemmOpLowering(MLIRContext *ctx)
: ConversionPattern(GemmOp::getOperationName(), 1, ctx) {} : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
bool hasBias = !op->getOperand(2).getType().isa<NoneType>(); bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
@ -32,12 +32,10 @@ struct ONNXGemmOpLowering : public ConversionPattern {
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
FloatAttr::get(memRefType.getElementType(), llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat()); auto betaAttr = FloatAttr::get(memRefType.getElementType(),
auto betaAttr = llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr); auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -101,8 +99,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
optimizedReductionLoops.reserve(1); optimizedReductionLoops.reserve(1);
reductionLoops.push_back(originalLoops[2]); reductionLoops.push_back(originalLoops[2]);
optimizedReductionLoops.push_back(optimizedLoops[2]); optimizedReductionLoops.push_back(optimizedLoops[2]);
KrnlIterateOperandPack reductionPack(rewriter, reductionLoops, KrnlIterateOperandPack reductionPack(
optimizedReductionLoops); rewriter, reductionLoops, optimizedReductionLoops);
// Induction variable for the reduction dimension // Induction variable for the reduction dimension
// Try to find and use a static value from A or B first. // Try to find and use a static value from A or B first.
// If it failed then use a dynamic value. // If it failed then use a dynamic value.
@ -167,8 +165,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs); auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB); auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (hasBias) { if (hasBias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, auto loopCIVs = getLoopIVsForBroadcasting(
broadcastedDimInfo); loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs); auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC); auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC); auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
@ -214,7 +212,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
} }
}; };
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns, void populateLoweringONNXGemmOpPattern(
MLIRContext *ctx) { OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx); patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
} }

View File

@ -16,9 +16,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
ONNXMatMulOpLowering(MLIRContext *ctx) ONNXMatMulOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
ONNXMatMulOpOperandAdaptor operandAdaptor(operands); ONNXMatMulOpOperandAdaptor operandAdaptor(operands);
@ -119,8 +118,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
// Define loops for batch dimensions. // Define loops for batch dimensions.
std::vector<Value> originalLoops; std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops; std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, Block *optimizationBlock = defineLoops(
optimizedLoops, memRefShape.size()); rewriter, loc, originalLoops, optimizedLoops, memRefShape.size());
// Outer KrnlIterateOp // Outer KrnlIterateOp
SmallVector<Value, 4> loopBatchIVs; SmallVector<Value, 4> loopBatchIVs;
@ -139,8 +138,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]); outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]); optimizedOuterLoops.push_back(optimizedLoops[i]);
} }
KrnlIterateOperandPack outerPack(rewriter, outerLoops, KrnlIterateOperandPack outerPack(
optimizedOuterLoops); rewriter, outerLoops, optimizedOuterLoops);
for (int i = 0; i < batchAxes.size(); ++i) { for (int i = 0; i < batchAxes.size(); ++i) {
addDimensionToPack(rewriter, loc, outerPack, alloc, i); addDimensionToPack(rewriter, loc, outerPack, alloc, i);
} }
@ -176,11 +175,11 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
optimizedMatmulLoops.emplace_back( optimizedMatmulLoops.emplace_back(
optimizedLoops[memRefShape.size() - i]); optimizedLoops[memRefShape.size() - i]);
} }
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, KrnlIterateOperandPack matmulPack(
optimizedMatmulLoops); rewriter, matmulLoops, optimizedMatmulLoops);
for (int i = 2; i > 0; --i) { for (int i = 2; i > 0; --i) {
addDimensionToPack(rewriter, loc, matmulPack, alloc, addDimensionToPack(
memRefShape.size() - i); rewriter, loc, matmulPack, alloc, memRefShape.size() - i);
} }
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack); matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
} else { } else {
@ -190,10 +189,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]); matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
optimizedMatmulLoops.emplace_back( optimizedMatmulLoops.emplace_back(
optimizedLoops[memRefShape.size() - 1]); optimizedLoops[memRefShape.size() - 1]);
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, KrnlIterateOperandPack matmulPack(
optimizedMatmulLoops); rewriter, matmulLoops, optimizedMatmulLoops);
addDimensionToPack(rewriter, loc, matmulPack, alloc, addDimensionToPack(
memRefShape.size() - 1); rewriter, loc, matmulPack, alloc, memRefShape.size() - 1);
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack); matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
} }
@ -230,8 +229,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
std::vector<Value> optimizedReduceLoops; std::vector<Value> optimizedReduceLoops;
Block *optimizationReduceBlock = Block *optimizationReduceBlock =
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
KrnlIterateOperandPack reducePack(rewriter, reduceLoops, KrnlIterateOperandPack reducePack(
optimizedReduceLoops); rewriter, reduceLoops, optimizedReduceLoops);
addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1); addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack); auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
@ -292,8 +291,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
std::vector<Value> optimizedReduceLoops; std::vector<Value> optimizedReduceLoops;
Block *optimizationReduceBlock = Block *optimizationReduceBlock =
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
KrnlIterateOperandPack reducePack(rewriter, reduceLoops, KrnlIterateOperandPack reducePack(
optimizedReduceLoops); rewriter, reduceLoops, optimizedReduceLoops);
addDimensionToPack(rewriter, loc, reducePack, A, 0); addDimensionToPack(rewriter, loc, reducePack, A, 0);
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack); auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);

View File

@ -102,9 +102,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
ONNXReductionOpLowering(MLIRContext *ctx) ONNXReductionOpLowering(MLIRContext *ctx)
: ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {} : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
/* /*
* Condition: reduction function must be associative and commutative. * Condition: reduction function must be associative and commutative.
* *

View File

@ -1,4 +1,5 @@
//===--------------- Conv.cpp - Lowering Convolution Op --------------------===// //===--------------- Conv.cpp - Lowering Convolution Op
//--------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -175,14 +176,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
// Emit the bias, if needed. // Emit the bias, if needed.
if (hasBias) { if (hasBias) {
auto loadResult = auto loadResult = rewriter.create<LoadOp>(loc, alloc, resultIndices);
rewriter.create<LoadOp>(loc, alloc, resultIndices);
SmallVector<Value, 4> biasIndices; SmallVector<Value, 4> biasIndices;
biasIndices.emplace_back(kernel); biasIndices.emplace_back(kernel);
auto loadBias = auto loadBias = rewriter.create<LoadOp>(loc, biasOperand, kernel);
rewriter.create<LoadOp>(loc, biasOperand, kernel); auto resultWithBias =
auto resultWithBias = rewriter.create<MulFOp>( rewriter.create<MulFOp>(loc, loadResult, loadBias);
loc, loadResult, loadBias);
// Store initializer value into output location. // Store initializer value into output location.
rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices); rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
} }

View File

@ -459,7 +459,8 @@ struct ONNXPoolOpLowering : public ConversionPattern {
poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext()); poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
// poolStartMap and poolEndMap // poolStartMap and poolEndMap
poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext()); poolStartMap =
AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext()); poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
} }

View File

@ -36,9 +36,7 @@ MemRefType convertToMemRefType(Type type) {
/// Insert an allocation and deallocation for the given MemRefType. /// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc, Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter, PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands) {
bool insertDealloc,
ArrayRef<Value> operands) {
// Put together alloc operands for any dynamic dimensions of the memref. // Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc; AllocOp alloc;
if (!operands.empty()) { if (!operands.empty()) {
@ -64,10 +62,10 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
auto operandDim = auto operandDim =
rewriter.create<DimOp>(loc, operands[i], operandDimIdx); rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
if (maxDim) { if (maxDim) {
auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, auto maxCondition = rewriter.create<CmpIOp>(
operandDim, maxDim); loc, CmpIPredicate::sgt, operandDim, maxDim);
maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim, maxDim = rewriter.create<SelectOp>(
maxDim); loc, maxCondition, operandDim, maxDim);
} else { } else {
maxDim = operandDim; maxDim = operandDim;
} }
@ -122,8 +120,8 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
// Create a mapping from result type's dimensions to input type's dimensions, // Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input // given that the result type is the result of a reduction op over the input
// type. // type.
std::map<int64_t, int64_t> std::map<int64_t, int64_t> getReductionMapping(
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) { MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> OutInDimMap; std::map<int64_t, int64_t> OutInDimMap;
int64_t rank = inputTy.getRank(); int64_t rank = inputTy.getRank();
@ -152,9 +150,8 @@ getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
// Add bounds associated with the op operand to the KRNL iteration pack. // Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimenions are supported. // Dynamic dimenions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter, void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
Location loc, KrnlIterateOperandPack &pack, KrnlIterateOperandPack &pack, Value operand, int index) {
Value operand, int index) {
auto shape = operand.getType().cast<MemRefType>().getShape(); auto shape = operand.getType().cast<MemRefType>().getShape();
if (shape[index] < 0) { if (shape[index] < 0) {
pack.pushConstantBound(0); pack.pushConstantBound(0);
@ -168,10 +165,9 @@ void addDimensionToPack(ConversionPatternRewriter &rewriter,
// Function that defines the KRNL dialect loops and their respective // Function that defines the KRNL dialect loops and their respective
// optimized version. // optimized version.
KrnlOptimizeLoopsOp KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc, Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
std::vector<Value> &loops, int64_t numLoops) {
std::vector<Value> &optimizedLoops, int64_t numLoops) {
// Define loops. // Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops); auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
loops.reserve(numLoops); loops.reserve(numLoops);
@ -190,9 +186,8 @@ emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
// Function that emits the loops and their optimized version. // Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block. // The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc, Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
std::vector<Value> &optimizedLoops, int64_t numLoops) {
int64_t numLoops) {
KrnlOptimizeLoopsOp optimizedLoopsOp = KrnlOptimizeLoopsOp optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops); emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
return &optimizedLoopsOp.region().front(); return &optimizedLoopsOp.region().front();
@ -201,10 +196,9 @@ Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
// Function which emits a basic set of loops and optimized loops // Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization // for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function. // block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand( void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
ConversionPatternRewriter &rewriter, Location loc, Value operand, Location loc, Value operand, std::vector<Value> &originalLoops,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
KrnlIterateOp &iterateOp) {
// Operand shape. // Operand shape.
auto shape = operand.getType().cast<MemRefType>().getShape(); auto shape = operand.getType().cast<MemRefType>().getShape();
@ -240,9 +234,9 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
std::map<int, std::map<int, Value>> std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter, MemRefType memRefType,
MemRefType memRefType, ArrayRef<Value> operands) { ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size(); int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in // For unknown dimensions, we need to get dimension values at runtime in
@ -286,10 +280,9 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// Extract induction variables that are used for broadcasting values of a // Extract induction variables that are used for broadcasting values of a
// given operand. // given operand.
std::vector<Value> std::vector<Value> getLoopIVsForBroadcasting(Location loc,
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
ArrayRef<Value> loopIVs, Value operand, std::map<int, Value> broadcastedDims) {
std::map<int, Value> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the // `operand` must has a ranked type. This should have been checked by the
// shape inference pass. // shape inference pass.
auto operandShape = operand.getType().cast<MemRefType>().getShape(); auto operandShape = operand.getType().cast<MemRefType>().getShape();
@ -310,8 +303,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
// If its value is 1, it is broadcasted dimension. // If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension. // Otherwise, non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0); auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero, auto idx = rewriter.create<SelectOp>(
loopIVs[loopIdx]); loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx); newLoopIVs.insert(newLoopIVs.begin(), idx);
} else { } else {
// Non-broadcasted dimension // Non-broadcasted dimension

View File

@ -30,7 +30,7 @@ struct ONNXConcatOpLowering : public ConversionPattern {
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto resultShape = memRefType.getShape(); auto resultShape = memRefType.getShape();
auto rank = resultShape.size(); auto rank = resultShape.size();
assert((axis >=0 && axis < rank) && "Concat axis out of bounds"); assert((axis >= 0 && axis < rank) && "Concat axis out of bounds");
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);

View File

@ -17,8 +17,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {
ONNXConstantOpLowering(MLIRContext *ctx) ONNXConstantOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) { : ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {
constantID = 0; constantID = 0;
} }
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
@ -34,12 +34,11 @@ struct ONNXConstantOpLowering : public ConversionPattern {
// Shape based computations. // Shape based computations.
auto shape = memRefType.getShape(); auto shape = memRefType.getShape();
int64_t numElements = 1; int64_t numElements = 1;
for (int i=0; i<shape.size(); ++i) for (int i = 0; i < shape.size(); ++i)
numElements *= shape[i]; numElements *= shape[i];
// Emit the constant global in Krnl dialect. // Emit the constant global in Krnl dialect.
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
memRefType,
rewriter.getI64ArrayAttr(shape), rewriter.getI64ArrayAttr(shape),
rewriter.getStringAttr("constant_" + std::to_string(constantID)), rewriter.getStringAttr("constant_" + std::to_string(constantID)),
constantOp.value().getValue()); constantOp.value().getValue());

View File

@ -24,15 +24,15 @@ enum Kinds {
} }
class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> { class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
public: public:
using Base::Base; using Base::Base;
// Support type inquiry through isa, cast and dyn_cast. // Support type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; } static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; }
// Get a unique instance of Loop type. // Get a unique instance of Loop type.
static LoopType get(mlir::MLIRContext* context) { static LoopType get(mlir::MLIRContext *context) {
return Base::get(context, KrnlTypes::Loop); return Base::get(context, KrnlTypes::Loop);
} }
}; };
} // namespace mlir } // namespace mlir

View File

@ -39,7 +39,6 @@ MLONNXOpsDialect::MLONNXOpsDialect(mlir::MLIRContext *ctx)
>(); >();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -19,14 +19,14 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Interface/PromotableConstOperandsOpInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
#include "src/Interface/ShapeInferenceInterface.hpp"
namespace mlir { namespace mlir {
class MLONNXOpsDialect : public Dialect { class MLONNXOpsDialect : public Dialect {
public: public:
MLONNXOpsDialect(MLIRContext* context); MLONNXOpsDialect(MLIRContext *context);
/// Provide a utility accessor to the dialect namespace. This is used by /// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects. /// several utilities for casting between dialects.
@ -38,6 +38,6 @@ class MLONNXOpsDialect : public Dialect {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Dialect/MLONNX/MLONNXOps.hpp.inc" #include "src/Dialect/MLONNX/MLONNXOps.hpp.inc"
} // end namespace mlir } // end namespace mlir
namespace onnx_mlir {} namespace onnx_mlir {}

View File

@ -19,14 +19,14 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Interface/PromotableConstOperandsOpInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
#include "src/Interface/ShapeInferenceInterface.hpp"
namespace mlir { namespace mlir {
class ONNXOpsDialect : public Dialect { class ONNXOpsDialect : public Dialect {
public: public:
ONNXOpsDialect(MLIRContext* context); ONNXOpsDialect(MLIRContext *context);
/// Provide a utility accessor to the dialect namespace. This is used by /// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects. /// several utilities for casting between dialects.
@ -38,6 +38,6 @@ class ONNXOpsDialect : public Dialect {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Dialect/ONNX/ONNXOps.hpp.inc" #include "src/Dialect/ONNX/ONNXOps.hpp.inc"
} // end namespace mlir } // end namespace mlir
namespace onnx_mlir {} namespace onnx_mlir {}

View File

@ -10,7 +10,6 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/Interface/PromotableConstOperandsOpInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
using namespace mlir; using namespace mlir;
@ -20,4 +19,3 @@ using namespace mlir;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc" #include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc"

View File

@ -22,4 +22,4 @@ namespace mlir {
/// Include the auto-generated declarations. /// Include the auto-generated declarations.
#include "src/Interface/PromotableConstOperandsOpInterface.hpp.inc" #include "src/Interface/PromotableConstOperandsOpInterface.hpp.inc"
} // end namespace mlir } // end namespace mlir

View File

@ -16,4 +16,4 @@ namespace mlir {
/// Include the auto-generated declarations. /// Include the auto-generated declarations.
#include "src/Interface/ShapeInference.cpp.inc" #include "src/Interface/ShapeInference.cpp.inc"
} // end namespace mlir } // end namespace mlir

View File

@ -18,4 +18,4 @@ namespace mlir {
/// Include the auto-generated declarations. /// Include the auto-generated declarations.
#include "src/Interface/ShapeInference.hpp.inc" #include "src/Interface/ShapeInference.hpp.inc"
} // end namespace mlir } // end namespace mlir

View File

@ -22,7 +22,7 @@ using namespace std;
using namespace onnx_mlir; using namespace onnx_mlir;
void LoadMLIR(string inputFilename, mlir::MLIRContext &context, void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module) { mlir::OwningModuleRef &module) {
// Handle '.mlir' input to the ONNX MLIR frontend. // Handle '.mlir' input to the ONNX MLIR frontend.
// The mlir format indicates that one or more of the supported // The mlir format indicates that one or more of the supported
// representations are used in the file. // representations are used in the file.
@ -46,10 +46,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
void EmitLLVMBitCode( void EmitLLVMBitCode(
const mlir::OwningModuleRef &module, string outputFilename) { const mlir::OwningModuleRef &module, string outputFilename) {
error_code error; error_code error;
llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error, llvm::raw_fd_ostream moduleBitcodeStream(
llvm::sys::fs::F_None); outputFilename, error, llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), llvm::WriteBitcodeToFile(
moduleBitcodeStream); *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
moduleBitcodeStream.flush(); moduleBitcodeStream.flush();
} }
@ -90,7 +90,7 @@ void addKrnlToLLVMPasses(mlir::PassManager &pm) {
} }
void processInputFile(string inputFilename, EmissionTargetType emissionTarget, void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
mlir::MLIRContext &context, mlir::OwningModuleRef &module) { mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// Decide if the input file is an ONNX model or a model specified // Decide if the input file is an ONNX model or a model specified
// in MLIR. The extension of the file is the decider. // in MLIR. The extension of the file is the decider.
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
@ -99,7 +99,6 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
assert(inputIsONNX != inputIsMLIR && assert(inputIsONNX != inputIsMLIR &&
"Either ONNX model or MLIR file needs to be provided."); "Either ONNX model or MLIR file needs to be provided.");
if (inputIsONNX) { if (inputIsONNX) {
ImportFrontendModelFile(inputFilename, context, module); ImportFrontendModelFile(inputFilename, context, module);
} else { } else {
@ -119,7 +118,7 @@ void outputCode(
module->dump(); module->dump();
fflush(stderr); fflush(stderr);
// set modified stderr as original stderr // set modified stderr as original stderr
_dup2(stderrOrigin, _fileno( stderr )); _dup2(stderrOrigin, _fileno(stderr));
#else #else
if (fork() == 0) { if (fork() == 0) {
freopen(tempFilename.c_str(), "w", stderr); freopen(tempFilename.c_str(), "w", stderr);
@ -151,7 +150,7 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
// necessary when emitting the .bc file. // necessary when emitting the .bc file.
if (emissionTarget == EmitLLVMBC) { if (emissionTarget == EmitLLVMBC) {
// Write LLVM bitcode to disk. // Write LLVM bitcode to disk.
string outputFilename = outputBaseName + ".bc"; string outputFilename = outputBaseName + ".bc";
EmitLLVMBitCode(module, outputFilename); EmitLLVMBitCode(module, outputFilename);
printf("LLVM bitcode written to %s\n", outputFilename.c_str()); printf("LLVM bitcode written to %s\n", outputFilename.c_str());
} else { } else {

View File

@ -28,9 +28,9 @@
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
@ -46,10 +46,10 @@ enum EmissionTargetType {
}; };
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context, void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module); mlir::OwningModuleRef &module);
void EmitLLVMBitCode( void EmitLLVMBitCode(
const mlir::OwningModuleRef &module, std::string outputFilename); const mlir::OwningModuleRef &module, std::string outputFilename);
void registerDialects(); void registerDialects();
@ -66,8 +66,7 @@ void processInputFile(std::string inputFilename,
mlir::OwningModuleRef &module); mlir::OwningModuleRef &module);
void outputCode( void outputCode(
mlir::OwningModuleRef &module, std::string filename, mlir::OwningModuleRef &module, std::string filename, std::string extension);
std::string extension);
void emitOutputFiles(std::string outputBaseName, void emitOutputFiles(std::string outputBaseName,
EmissionTargetType emissionTarget, mlir::MLIRContext &context, EmissionTargetType emissionTarget, mlir::MLIRContext &context,

View File

@ -38,4 +38,4 @@ std::unique_ptr<Pass> createElideConstGlobalValuePass();
/// Pass for lowering Krnl dialect to LLVM dialect. /// Pass for lowering Krnl dialect to LLVM dialect.
std::unique_ptr<Pass> createKrnlLowerToLLVMPass(); std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
} // end namespace mlir } // end namespace mlir

View File

@ -1,15 +1,15 @@
enum DYN_MEMREF_DATA_TYPE { enum DYN_MEMREF_DATA_TYPE {
UNDEFINED = 0; UNDEFINED = 0;
// Basic types. // Basic types.
FLOAT = 1; // float FLOAT = 1; // float
UINT8 = 2; // uint8_t UINT8 = 2; // uint8_t
INT8 = 3; // int8_t INT8 = 3; // int8_t
UINT16 = 4; // uint16_t UINT16 = 4; // uint16_t
INT16 = 5; // int16_t INT16 = 5; // int16_t
INT32 = 6; // int32_t INT32 = 6; // int32_t
INT64 = 7; // int64_t INT64 = 7; // int64_t
STRING = 8; // string STRING = 8; // string
BOOL = 9; // bool BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide). // IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
@ -18,8 +18,8 @@ enum DYN_MEMREF_DATA_TYPE {
DOUBLE = 11; DOUBLE = 11;
UINT32 = 12; UINT32 = 12;
UINT64 = 13; UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision // Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits. // floating-point number truncated to 16 bits.

View File

@ -33,8 +33,8 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
return tensorDict->tensorDict[tensorDict->orderedNames[idx]]; return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
} }
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, void setDynMemRef(
DynMemRef *tensor) { OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
if (tensorDict->orderedNames.size() <= idx) if (tensorDict->orderedNames.size() <= idx)
tensorDict->orderedNames.resize(idx + 1); tensorDict->orderedNames.resize(idx + 1);

View File

@ -39,7 +39,6 @@ typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
extern "C" { extern "C" {
#endif #endif
// Get number of dynamic memrefs in OrderedDynMemRefDict dict. // Get number of dynamic memrefs in OrderedDynMemRefDict dict.
int numDynMemRefs(OrderedDynMemRefDict *dict); int numDynMemRefs(OrderedDynMemRefDict *dict);
@ -53,8 +52,8 @@ DynMemRef *createDynMemRef(int rank);
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i); DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
// Set the i-th dynmemref in orderedDict to be dynMemRef. // Set the i-th dynmemref in orderedDict to be dynMemRef.
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, void setDynMemRef(
DynMemRef *dynMemRef); OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
// Get data pointer from dynMemRef. // Get data pointer from dynMemRef.
void *getData(DynMemRef *dynMemRef); void *getData(DynMemRef *dynMemRef);

View File

@ -1,14 +1,14 @@
#include "Runtime.hpp" #include "Runtime.hpp"
ExecutionSession::ExecutionSession(std::string sharedLibPath, ExecutionSession::ExecutionSession(
std::string entryPointName) { std::string sharedLibPath, std::string entryPointName) {
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY); _sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
_entryPointFunc = _entryPointFunc =
(entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str()); (entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
} }
std::vector<py::array> std::vector<py::array> ExecutionSession::run(
ExecutionSession::run(std::vector<py::array> inputsPyArray) { std::vector<py::array> inputsPyArray) {
assert(_entryPointFunc && "entry point not loaded"); assert(_entryPointFunc && "entry point not loaded");
auto *wrappedInput = createOrderedDynMemRefDict(); auto *wrappedInput = createOrderedDynMemRefDict();
int inputIdx = 0; int inputIdx = 0;
@ -40,8 +40,8 @@ ExecutionSession::run(std::vector<py::array> inputsPyArray) {
auto *wrappedOutput = _entryPointFunc(wrappedInput); auto *wrappedOutput = _entryPointFunc(wrappedInput);
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) { for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
auto *dynMemRef = getDynMemRef(wrappedOutput, i); auto *dynMemRef = getDynMemRef(wrappedOutput, i);
auto shape = std::vector<int64_t>(dynMemRef->sizes, auto shape = std::vector<int64_t>(
dynMemRef->sizes + dynMemRef->rank); dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
outputPyArrays.emplace_back( outputPyArrays.emplace_back(
py::array(py::dtype("float32"), shape, dynMemRef->data)); py::array(py::dtype("float32"), shape, dynMemRef->data));
} }

View File

@ -144,9 +144,9 @@ public:
assert(krnlGlobalOp.value().hasValue() && assert(krnlGlobalOp.value().hasValue() &&
"Krnl Global must always have a value"); "Krnl Global must always have a value");
global = rewriter.create<LLVM::GlobalOp>(loc, global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
llvmGlobalType, /*isConstant=*/true, /*isConstant=*/true, LLVM::Linkage::Internal, name,
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue()); krnlGlobalOp.value().getValue());
} }
// Some frequently used types. // Some frequently used types.

View File

@ -75,7 +75,7 @@ public:
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto *context = &getContext(); auto *context = &getContext();
ConstantOp::getCanonicalizationPatterns(patterns, context); ConstantOp::getCanonicalizationPatterns(patterns, context);
applyPatternsAndFoldGreedily(f, patterns); applyPatternsAndFoldGreedily(f, patterns);
} }
}; };
} // end anonymous namespace } // end anonymous namespace

View File

@ -12,35 +12,35 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include <numeric>
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include <numeric>
using namespace mlir; using namespace mlir;
namespace { namespace {
/// Include the patterns defined in the Declarative Rewrite framework. /// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXCombine.inc" #include "src/Transform/ONNX/ONNXCombine.inc"
} // end anonymous namespace } // end anonymous namespace
/// Register optimization patterns as "canonicalization" patterns /// Register optimization patterns as "canonicalization" patterns
/// on the ONNXMatMultOp. /// on the ONNXMatMultOp.
void ONNXAddOp::getCanonicalizationPatterns( void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<MulAddToGemmOptPattern>(context); results.insert<MulAddToGemmOptPattern>(context);
} }
void ONNXGemmOp::getCanonicalizationPatterns( void ONNXGemmOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FuseGemmFollowedByAddition>(context); results.insert<FuseGemmFollowedByAddition>(context);
} }
/// on the ONNXIdentityOp. /// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns( void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<IdentityEliminationPattern>(context); results.insert<IdentityEliminationPattern>(context);
} }
///on the ONNXPadConstantValueOp. /// on the ONNXPadConstantValueOp.
void ONNXPadConstantValueOp::getCanonicalizationPatterns( void ONNXPadConstantValueOp::getCanonicalizationPatterns(
OwningRewritePatternList& result, MLIRContext* context) { OwningRewritePatternList &result, MLIRContext *context) {
result.insert<ConstantPadPattern>(context); result.insert<ConstantPadPattern>(context);
} }

View File

@ -27,7 +27,8 @@ namespace {
/// Include the patterns defined in the Declarative Rewrite framework. /// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc" #include "src/Transform/ONNX/ONNXDecompose.inc"
struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> { struct DecomposeONNXToONNXPass
: public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
void runOnFunction() final; void runOnFunction() final;
}; };
} // end anonymous namespace. } // end anonymous namespace.

View File

@ -9,10 +9,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp" #include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
@ -25,7 +25,8 @@ namespace {
* candidate operations and propagating the shape information until the list * candidate operations and propagating the shape information until the list
* of operations is empty [credit MLIR authors]. * of operations is empty [credit MLIR authors].
*/ */
class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> { class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
public: public:
void runOnFunction() override { void runOnFunction() override {
auto f = getFunction(); auto f = getFunction();
@ -63,8 +64,7 @@ public:
if (auto terminator_op = f.getBody().back().getTerminator()) { if (auto terminator_op = f.getBody().back().getTerminator()) {
auto results = terminator_op->getOperandTypes(); auto results = terminator_op->getOperandTypes();
f.setType(FunctionType::get( f.setType(FunctionType::get(f.getType().getInputs(),
f.getType().getInputs(),
std::vector<Type>(results.begin(), results.end()), f.getContext())); std::vector<Type>(results.begin(), results.end()), f.getContext()));
} }
} }
@ -146,5 +146,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>(); return std::make_unique<ShapeInferencePass>();
} }
static PassRegistration<ShapeInferencePass> static PassRegistration<ShapeInferencePass> pass(
pass("shape-inference", "Shape inference for frontend dialects."); "shape-inference", "Shape inference for frontend dialects.");

View File

@ -14,30 +14,30 @@ using namespace onnx_mlir;
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
registerDialects(); registerDialects();
llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options", llvm::cl::OptionCategory OnnxMlirOptions(
"These are frontend options."); "ONNX MLIR Options", "These are frontend options.");
llvm::cl::opt<string> inputFilename( llvm::cl::opt<string> inputFilename(llvm::cl::Positional,
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"), llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<EmissionTargetType> emissionTarget( llvm::cl::opt<EmissionTargetType> emissionTarget(
llvm::cl::desc("Choose target to emit:"), llvm::cl::desc("Choose target to emit:"),
llvm::cl::values( llvm::cl::values(
clEnumVal(EmitONNXBasic, clEnumVal(EmitONNXBasic,
"Ingest ONNX and emit the basic ONNX operations without" "Ingest ONNX and emit the basic ONNX operations without"
"inferred shapes."), "inferred shapes."),
clEnumVal(EmitONNXIR, clEnumVal(
"Ingest ONNX and emit corresponding ONNX dialect."), EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."),
clEnumVal(EmitMLIR, clEnumVal(
"Lower model to MLIR built-in transformation dialect."), EmitMLIR, "Lower model to MLIR built-in transformation dialect."),
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."), clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) " clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) "
"LLVM bitcode for model.")), "LLVM bitcode for model.")),
llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions); llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
llvm::cl::ParseCommandLineOptions(argc, argv, llvm::cl::ParseCommandLineOptions(
"ONNX MLIR modular optimizer driver\n"); argc, argv, "ONNX MLIR modular optimizer driver\n");
mlir::MLIRContext context; mlir::MLIRContext context;
mlir::OwningModuleRef module; mlir::OwningModuleRef module;