[NFC] Set up clang-format Github Action (#119)

* Run clang-format on all source code.

* Add Clang-Format Github Action.

* Apply patch produced by Clang-Format Bot.

* nit.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Author: Tian Jin, 2020-05-13 22:37:51 +08:00, committed by GitHub
parent 24343177b8
commit 7f2bffb27d
36 changed files with 272 additions and 260 deletions
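
The reformatting in the hunks below is consistent with an LLVM-based clang-format style: an 80-column limit, & and * bound to the variable name, and continuation lines packed at a four-space indent rather than aligned under the opening parenthesis. The project's .clang-format file is not shown in this diff; a minimal sketch that would roughly reproduce this behavior (an assumption, not the actual configuration) is:

# Hypothetical .clang-format; not part of this commit, and the project's
# actual configuration may differ.
BasedOnStyle: LLVM
# The packed, four-space continuation lines in the hunks below suggest:
AlignAfterOpenBracket: DontAlign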

.github/workflows/main.yml (new file)

@ -0,0 +1,40 @@
# This is a basic workflow to help you get started with Actions
name: Clang-Format Bot

# Controls when the action will run. Triggers the workflow on push or pull request
# events but only for the master branch
on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
  # This workflow contains a single job called "build"
  build:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest

    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
      - uses: actions/checkout@v2

      - name: clang-format lint
        uses: DoozyX/clang-format-lint-action@v0.5
        with:
          # Source folder to check formatting
          source: ./src
          # Version of clang-format
          clangFormatVersion: 9 # optional, default is 9

      # Runs a single command using the runners shell
      - name: Run a one-line script
        run: echo Hello, world!

      # Runs a set of commands using the runners shell
      - name: Run a multi-line script
        run: |
          echo Add other actions to build,
          echo test, and deploy your project.
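
The two trailing echo steps above are placeholders carried over from GitHub's default Actions template; only the DoozyX lint step enforces formatting, checking the ./src tree with clang-format 9. As a rough local equivalent, a contributor could reformat in place and check for a resulting diff before pushing. The sketch below expresses this as an extra workflow step; the apt package name and the file extensions are assumptions, not part of this commit:

      # Hypothetical manual equivalent of the lint step above (a sketch, not
      # part of this commit).
      - name: clang-format check (manual sketch)
        run: |
          sudo apt-get update && sudo apt-get install -y clang-format-9
          find ./src \( -name '*.cpp' -o -name '*.hpp' \) -exec clang-format-9 -i {} +
          git diff --exit-code  # fails if any file needed reformatting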


@ -12,8 +12,8 @@
namespace onnx_mlir {
void replaceAll(std::string &str, const std::string &from,
const std::string &to) {
void replaceAll(
std::string &str, const std::string &from, const std::string &to) {
if (from.empty())
return;
size_t start_pos = 0;
@ -121,7 +121,6 @@ void InitializedTensorMapping::AddMapping(
nameToInitializedTensor.emplace(name, tensor);
}
bool InitializedTensorMapping::ContainKey(std::string name) {
return nameToInitializedTensor.count(name) != 0;
}
@ -132,8 +131,8 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
onnx::TensorProto initializer = GetInitializedTensor(name);
// Tensor dimensions.
llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
initializer.dims().size());
llvm::ArrayRef<int64_t> tensorDims(
initializer.dims().data(), initializer.dims().size());
// Emit ConstantOp and record the mapping between the input and
// the constant value.
@ -142,33 +141,32 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
mlir::Type elementType;
mlir::ShapedType tensorType;
switch (initializer.data_type()) {
case (onnx::TensorProto::FLOAT): {
const auto& arrayAttrInitializer =
CreateArrayAttribute<float>(initializer);
elementType = builder.getF32Type();
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::INT32): {
const auto& arrayAttrInitializer =
CreateArrayAttribute<int32_t>(initializer);
elementType = builder.getIntegerType(32);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::INT64): {
const auto& arrayAttrInitializer =
CreateArrayAttribute<int64_t>(initializer);
elementType = builder.getIntegerType(64);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::FLOAT): {
const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
elementType = builder.getF32Type();
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::INT32): {
const auto &arrayAttrInitializer =
CreateArrayAttribute<int32_t>(initializer);
elementType = builder.getIntegerType(32);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::INT64): {
const auto &arrayAttrInitializer =
CreateArrayAttribute<int64_t>(initializer);
elementType = builder.getIntegerType(64);
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
}
// Create ConstantOp for dense array.


@ -20,8 +20,8 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
@ -37,11 +37,12 @@
#endif
#include "onnx/onnx_pb.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
void replaceAll(std::string &str, const std::string &from,
const std::string &to);
void replaceAll(
std::string &str, const std::string &from, const std::string &to);
std::string legalize_name(std::string name);
@ -86,14 +87,14 @@ struct InitializedTensorMapping {
// This will allow the propagation of shape information passed in as an
// argument to operations such as Reshape and will enable other
// optimizations such as constant folding.
mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
mlir::OpBuilder &builder, std::string name);
mlir::Value EmitInitializerForInputTensor(
mlir::Location loc, mlir::OpBuilder &builder, std::string name);
// Get initialized tensor.
onnx::TensorProto& GetInitializedTensor(std::string name) {
assert(nameToInitializedTensor.find(name) !=
nameToInitializedTensor.end() &&
"Tensor initializer not found");
onnx::TensorProto &GetInitializedTensor(std::string name) {
assert(
nameToInitializedTensor.find(name) != nameToInitializedTensor.end() &&
"Tensor initializer not found");
return nameToInitializedTensor.at(name);
}


@ -127,8 +127,8 @@ private:
* @param input onnx input tensor ValueInfoProto.
* @param symbol mlir input argument.
*/
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
mlir::Value symbol) {
void ImportInputTensorSymbol(
const onnx::ValueInfoProto &input, mlir::Value symbol) {
auto input_tensor_legalized_name = legalize_name(input.name());
assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
"Found duplicate legalized input tensor names.");
@ -136,8 +136,7 @@ private:
}
typedef bstd::variant<int64_t, std::vector<int64_t>, float,
std::vector<float>, std::string,
std::vector<std::string>>
std::vector<float>, std::string, std::vector<std::string>>
AttrValueType;
struct ONNXAttrVisitor {
@ -213,8 +212,8 @@ private:
llvm_unreachable("Failed to convert attribute proto to name/value pair");
}
std::vector<mlir::NamedAttribute>
ImportNodeAttributes(const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
@ -281,26 +280,25 @@ private:
// TODO: Handle optional inputs.
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
*(op.getODSResults(i).begin()));
frontend_symbols_.AddMapping(
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
}
}
template <typename T>
void buildOperation(const onnx::NodeProto &node,
int expectedNumOperands = -1,
int expectedNumResults = -1) {
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
int expectedNumResults = -1) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input())
if (initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
UnknownLoc(), builder_, legalize_name(item)));
UnknownLoc(), builder_, legalize_name(item)));
} else if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
expectedNumResults);
buildOutputAndOperation<T>(
node, inputs, expectedNumOperands, expectedNumResults);
}
void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
@ -310,9 +308,8 @@ private:
item = node.input()[i];
// For the second argument, check if there exists an initializer.
if (initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back(
initializedTensors.EmitInitializerForInputTensor(
UnknownLoc(), builder_, legalize_name(item)));
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
UnknownLoc(), builder_, legalize_name(item)));
} else if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
@ -372,7 +369,6 @@ private:
#if INCLUDE_ONNX_ML == 1
#include "src/Builder/MLOpBuildTable.inc"
#endif
}
/*!
@ -388,8 +384,8 @@ private:
* output tensor.
*/
void ImportOutputTensor(const onnx::ValueInfoProto &output,
llvm::SmallVectorImpl<mlir::Type> &ret_types,
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
llvm::SmallVectorImpl<mlir::Type> &ret_types,
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
auto output_tensor_legalized_name = legalize_name(output.name());
assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
"Output tensor not found");
@ -400,8 +396,8 @@ private:
ret_vals.push_back(tensor_val);
}
void ImportGraph(const onnx::GraphProto &graph,
const std::string &name = "main_graph") {
void ImportGraph(
const onnx::GraphProto &graph, const std::string &name = "main_graph") {
// Maintain a mapping between the parameter and its initializer.
for (auto initializer : graph.initializer()) {
auto name = initializer.name();
@ -426,8 +422,7 @@ private:
// Emit the entry point operation which specifies the number of user
// inputs and outputs.
auto entryPoint = mlir::ONNXEntryPointOp::create(
UnknownLoc(), mainFunc,
auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
/*numInputs=*/graph.input().size() - graph.initializer().size(),
/*numOutputs=*/graph.output().size());
@ -454,8 +449,8 @@ private:
// Create a NoneTyped constant to be used for optional operation inputs
// which are not used.
none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
builder_.getUnitAttr());
none_ =
builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph.
for (const auto &item : graph.node()) {
@ -483,8 +478,7 @@ private:
namespace onnx_mlir {
void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext &context,
mlir::OwningModuleRef &module) {
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
onnx::ModelProto model;
std::fstream input(model_fname, std::ios::in | std::ios::binary);


@ -36,11 +36,10 @@ namespace onnx_mlir {
* @return MLIR::module generated for the ONNX model.
*/
void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext &context,
mlir::OwningModuleRef &module);
mlir::MLIRContext &context, mlir::OwningModuleRef &module);
/*!
* TODO: Import models into other extension dialects that cover the
* operations specific to other frameworks such as Tensorflow or Pytorch.
*/
} // namespace onnx_mlir
} // namespace onnx_mlir


@ -1,4 +1,5 @@
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering --------===//
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering
//--------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -21,10 +22,9 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public:
using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXEntryPointOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
op,
LogicalResult matchAndRewrite(
ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(op,
op.getAttrOfType<SymbolRefAttr>(
ONNXEntryPointOp::getEntryPointFuncAttrName()),
op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
@ -55,8 +55,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// We define the specific operations, or dialects, that are legal targets for
// this lowering.
target
.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
// TODO: enable this once more ops are supported.
// We also define the ONNX dialect as Illegal so that the conversion will fail
@ -81,8 +80,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// Type conversion for function signatures.
// Call MLIR FuncOp signature conversion when result type is
// a ranked tensor.
populateFuncOpTypeConversionPattern(patterns, &getContext(),
tensor_to_memref_converter);
populateFuncOpTypeConversionPattern(
patterns, &getContext(), tensor_to_memref_converter);
// Frontend operation lowering.
// Math
@ -119,5 +118,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
return std::make_unique<FrontendToKrnlLoweringPass>();
}
static PassRegistration<FrontendToKrnlLoweringPass>
pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
static PassRegistration<FrontendToKrnlLoweringPass> pass(
"lower-frontend", "Lower frontend ops to Krnl dialect.");


@ -499,9 +499,8 @@ template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of
// the same type. This should have been verified by the verifier.
@ -566,9 +565,8 @@ template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result
// of the same type. This should have been verified by the verifier.


@ -1,4 +1,5 @@
//===----------------- Gemm.cpp - Lowering Gemm Op -------------------------===//
//===----------------- Gemm.cpp - Lowering Gemm Op
//-------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -17,9 +18,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
ONNXGemmOpLowering(MLIRContext *ctx)
: ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
@ -32,12 +32,10 @@ struct ONNXGemmOpLowering : public ConversionPattern {
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -101,8 +99,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
optimizedReductionLoops.reserve(1);
reductionLoops.push_back(originalLoops[2]);
optimizedReductionLoops.push_back(optimizedLoops[2]);
KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
optimizedReductionLoops);
KrnlIterateOperandPack reductionPack(
rewriter, reductionLoops, optimizedReductionLoops);
// Induction variable for the reduction dimension
// Try to find and use a static value from A or B first.
// If it failed then use a dynamic value.
@ -167,8 +165,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (hasBias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
broadcastedDimInfo);
auto loopCIVs = getLoopIVsForBroadcasting(
loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
@ -214,7 +212,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
}
};
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx) {
void populateLoweringONNXGemmOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
}


@ -16,9 +16,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
ONNXMatMulOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
ONNXMatMulOpOperandAdaptor operandAdaptor(operands);
@ -119,8 +118,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
// Define loops for batch dimensions.
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, memRefShape.size());
Block *optimizationBlock = defineLoops(
rewriter, loc, originalLoops, optimizedLoops, memRefShape.size());
// Outer KrnlIterateOp
SmallVector<Value, 4> loopBatchIVs;
@ -139,8 +138,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
optimizedOuterLoops);
KrnlIterateOperandPack outerPack(
rewriter, outerLoops, optimizedOuterLoops);
for (int i = 0; i < batchAxes.size(); ++i) {
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
}
@ -176,11 +175,11 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
optimizedMatmulLoops.emplace_back(
optimizedLoops[memRefShape.size() - i]);
}
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
optimizedMatmulLoops);
KrnlIterateOperandPack matmulPack(
rewriter, matmulLoops, optimizedMatmulLoops);
for (int i = 2; i > 0; --i) {
addDimensionToPack(rewriter, loc, matmulPack, alloc,
memRefShape.size() - i);
addDimensionToPack(
rewriter, loc, matmulPack, alloc, memRefShape.size() - i);
}
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
} else {
@ -190,10 +189,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
optimizedMatmulLoops.emplace_back(
optimizedLoops[memRefShape.size() - 1]);
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
optimizedMatmulLoops);
addDimensionToPack(rewriter, loc, matmulPack, alloc,
memRefShape.size() - 1);
KrnlIterateOperandPack matmulPack(
rewriter, matmulLoops, optimizedMatmulLoops);
addDimensionToPack(
rewriter, loc, matmulPack, alloc, memRefShape.size() - 1);
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
}
@ -230,8 +229,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
std::vector<Value> optimizedReduceLoops;
Block *optimizationReduceBlock =
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
optimizedReduceLoops);
KrnlIterateOperandPack reducePack(
rewriter, reduceLoops, optimizedReduceLoops);
addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
@ -292,8 +291,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
std::vector<Value> optimizedReduceLoops;
Block *optimizationReduceBlock =
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
optimizedReduceLoops);
KrnlIterateOperandPack reducePack(
rewriter, reduceLoops, optimizedReduceLoops);
addDimensionToPack(rewriter, loc, reducePack, A, 0);
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);


@ -102,9 +102,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
ONNXReductionOpLowering(MLIRContext *ctx)
: ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
/*
* Condition: reduction function must be associative and commutative.
*


@ -1,4 +1,5 @@
//===--------------- Conv.cpp - Lowering Convolution Op --------------------===//
//===--------------- Conv.cpp - Lowering Convolution Op
//--------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -175,14 +176,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
// Emit the bias, if needed.
if (hasBias) {
auto loadResult =
rewriter.create<LoadOp>(loc, alloc, resultIndices);
auto loadResult = rewriter.create<LoadOp>(loc, alloc, resultIndices);
SmallVector<Value, 4> biasIndices;
biasIndices.emplace_back(kernel);
auto loadBias =
rewriter.create<LoadOp>(loc, biasOperand, kernel);
auto resultWithBias = rewriter.create<MulFOp>(
loc, loadResult, loadBias);
auto loadBias = rewriter.create<LoadOp>(loc, biasOperand, kernel);
auto resultWithBias =
rewriter.create<MulFOp>(loc, loadResult, loadBias);
// Store initializer value into output location.
rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
}


@ -459,7 +459,8 @@ struct ONNXPoolOpLowering : public ConversionPattern {
poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
// poolStartMap and poolEndMap
poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
poolStartMap =
AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
}


@ -36,9 +36,7 @@ MemRefType convertToMemRefType(Type type) {
/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value> operands) {
PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands) {
// Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc;
if (!operands.empty()) {
@ -64,10 +62,10 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
auto operandDim =
rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
if (maxDim) {
auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
operandDim, maxDim);
maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
maxDim);
auto maxCondition = rewriter.create<CmpIOp>(
loc, CmpIPredicate::sgt, operandDim, maxDim);
maxDim = rewriter.create<SelectOp>(
loc, maxCondition, operandDim, maxDim);
} else {
maxDim = operandDim;
}
@ -122,8 +120,8 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> getReductionMapping(
MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> OutInDimMap;
int64_t rank = inputTy.getRank();
@ -152,9 +150,8 @@ getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimenions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
Location loc, KrnlIterateOperandPack &pack,
Value operand, int index) {
void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
KrnlIterateOperandPack &pack, Value operand, int index) {
auto shape = operand.getType().cast<MemRefType>().getShape();
if (shape[index] < 0) {
pack.pushConstantBound(0);
@ -168,10 +165,9 @@ void addDimensionToPack(ConversionPatternRewriter &rewriter,
// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops, int64_t numLoops) {
KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
int64_t numLoops) {
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
loops.reserve(numLoops);
@ -190,9 +186,8 @@ emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops,
int64_t numLoops) {
std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
int64_t numLoops) {
KrnlOptimizeLoopsOp optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
return &optimizedLoopsOp.region().front();
@ -201,10 +196,9 @@ Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
ConversionPatternRewriter &rewriter, Location loc, Value operand,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
KrnlIterateOp &iterateOp) {
void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
Location loc, Value operand, std::vector<Value> &originalLoops,
KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
// Operand shape.
auto shape = operand.getType().cast<MemRefType>().getShape();
@ -240,9 +234,9 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value> operands) {
std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
ConversionPatternRewriter &rewriter, MemRefType memRefType,
ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in
@ -286,10 +280,9 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims) {
std::vector<Value> getLoopIVsForBroadcasting(Location loc,
ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the
// shape inference pass.
auto operandShape = operand.getType().cast<MemRefType>().getShape();
@ -310,8 +303,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
// If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
loopIVs[loopIdx]);
auto idx = rewriter.create<SelectOp>(
loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx);
} else {
// Non-broadcasted dimension


@ -30,7 +30,7 @@ struct ONNXConcatOpLowering : public ConversionPattern {
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto resultShape = memRefType.getShape();
auto rank = resultShape.size();
assert((axis >=0 && axis < rank) && "Concat axis out of bounds");
assert((axis >= 0 && axis < rank) && "Concat axis out of bounds");
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);


@ -17,8 +17,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {
ONNXConstantOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {
constantID = 0;
}
constantID = 0;
}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@ -34,12 +34,11 @@ struct ONNXConstantOpLowering : public ConversionPattern {
// Shape based computations.
auto shape = memRefType.getShape();
int64_t numElements = 1;
for (int i=0; i<shape.size(); ++i)
for (int i = 0; i < shape.size(); ++i)
numElements *= shape[i];
// Emit the constant global in Krnl dialect.
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
memRefType,
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
rewriter.getI64ArrayAttr(shape),
rewriter.getStringAttr("constant_" + std::to_string(constantID)),
constantOp.value().getValue());


@ -24,15 +24,15 @@ enum Kinds {
}
class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
public:
public:
using Base::Base;
// Support type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; }
// Get a unique instance of Loop type.
static LoopType get(mlir::MLIRContext* context) {
static LoopType get(mlir::MLIRContext *context) {
return Base::get(context, KrnlTypes::Loop);
}
};
} // namespace mlir
} // namespace mlir


@ -39,7 +39,6 @@ MLONNXOpsDialect::MLONNXOpsDialect(mlir::MLIRContext *ctx)
>();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//


@ -19,14 +19,14 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
#include "src/Interface/ShapeInferenceInterface.hpp"
namespace mlir {
class MLONNXOpsDialect : public Dialect {
public:
MLONNXOpsDialect(MLIRContext* context);
public:
MLONNXOpsDialect(MLIRContext *context);
/// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects.
@ -38,6 +38,6 @@ class MLONNXOpsDialect : public Dialect {
#define GET_OP_CLASSES
#include "src/Dialect/MLONNX/MLONNXOps.hpp.inc"
} // end namespace mlir
} // end namespace mlir
namespace onnx_mlir {}


@ -19,14 +19,14 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
#include "src/Interface/ShapeInferenceInterface.hpp"
namespace mlir {
class ONNXOpsDialect : public Dialect {
public:
ONNXOpsDialect(MLIRContext* context);
public:
ONNXOpsDialect(MLIRContext *context);
/// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects.
@ -38,6 +38,6 @@ class ONNXOpsDialect : public Dialect {
#define GET_OP_CLASSES
#include "src/Dialect/ONNX/ONNXOps.hpp.inc"
} // end namespace mlir
} // end namespace mlir
namespace onnx_mlir {}


@ -10,7 +10,6 @@
//
//===----------------------------------------------------------------------===//
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
using namespace mlir;
@ -20,4 +19,3 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
#include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc"


@ -22,4 +22,4 @@ namespace mlir {
/// Include the auto-generated declarations.
#include "src/Interface/PromotableConstOperandsOpInterface.hpp.inc"
} // end namespace mlir
} // end namespace mlir


@ -16,4 +16,4 @@ namespace mlir {
/// Include the auto-generated declarations.
#include "src/Interface/ShapeInference.cpp.inc"
} // end namespace mlir
} // end namespace mlir


@ -18,4 +18,4 @@ namespace mlir {
/// Include the auto-generated declarations.
#include "src/Interface/ShapeInference.hpp.inc"
} // end namespace mlir
} // end namespace mlir


@ -22,7 +22,7 @@ using namespace std;
using namespace onnx_mlir;
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module) {
mlir::OwningModuleRef &module) {
// Handle '.mlir' input to the ONNX MLIR frontend.
// The mlir format indicates that one or more of the supported
// representations are used in the file.
@ -46,10 +46,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
void EmitLLVMBitCode(
const mlir::OwningModuleRef &module, string outputFilename) {
error_code error;
llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error,
llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
moduleBitcodeStream);
llvm::raw_fd_ostream moduleBitcodeStream(
outputFilename, error, llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
moduleBitcodeStream.flush();
}
@ -90,7 +90,7 @@ void addKrnlToLLVMPasses(mlir::PassManager &pm) {
}
void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// Decide if the input file is an ONNX model or a model specified
// in MLIR. The extension of the file is the decider.
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
@ -99,7 +99,6 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
assert(inputIsONNX != inputIsMLIR &&
"Either ONNX model or MLIR file needs to be provided.");
if (inputIsONNX) {
ImportFrontendModelFile(inputFilename, context, module);
} else {
@ -119,7 +118,7 @@ void outputCode(
module->dump();
fflush(stderr);
// set modified stderr as original stderr
_dup2(stderrOrigin, _fileno( stderr ));
_dup2(stderrOrigin, _fileno(stderr));
#else
if (fork() == 0) {
freopen(tempFilename.c_str(), "w", stderr);
@ -151,7 +150,7 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
// necessary when emitting the .bc file.
if (emissionTarget == EmitLLVMBC) {
// Write LLVM bitcode to disk.
string outputFilename = outputBaseName + ".bc";
string outputFilename = outputBaseName + ".bc";
EmitLLVMBitCode(module, outputFilename);
printf("LLVM bitcode written to %s\n", outputFilename.c_str());
} else {


@ -28,9 +28,9 @@
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@ -46,10 +46,10 @@ enum EmissionTargetType {
};
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module);
mlir::OwningModuleRef &module);
void EmitLLVMBitCode(
const mlir::OwningModuleRef &module, std::string outputFilename);
const mlir::OwningModuleRef &module, std::string outputFilename);
void registerDialects();
@ -66,8 +66,7 @@ void processInputFile(std::string inputFilename,
mlir::OwningModuleRef &module);
void outputCode(
mlir::OwningModuleRef &module, std::string filename,
std::string extension);
mlir::OwningModuleRef &module, std::string filename, std::string extension);
void emitOutputFiles(std::string outputBaseName,
EmissionTargetType emissionTarget, mlir::MLIRContext &context,


@ -38,4 +38,4 @@ std::unique_ptr<Pass> createElideConstGlobalValuePass();
/// Pass for lowering Krnl dialect to LLVM dialect.
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
} // end namespace mlir
} // end namespace mlir


@ -1,15 +1,15 @@
enum DYN_MEMREF_DATA_TYPE {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
@ -18,8 +18,8 @@ enum DYN_MEMREF_DATA_TYPE {
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.


@ -33,8 +33,8 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
}
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
DynMemRef *tensor) {
void setDynMemRef(
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
if (tensorDict->orderedNames.size() <= idx)
tensorDict->orderedNames.resize(idx + 1);


@ -39,7 +39,6 @@ typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
extern "C" {
#endif
// Get number of dynamic memrefs in OrderedDynMemRefDict dict.
int numDynMemRefs(OrderedDynMemRefDict *dict);
@ -53,8 +52,8 @@ DynMemRef *createDynMemRef(int rank);
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
// Set the i-th dynmemref in orderedDict to be dynMemRef.
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
DynMemRef *dynMemRef);
void setDynMemRef(
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
// Get data pointer from dynMemRef.
void *getData(DynMemRef *dynMemRef);


@ -1,14 +1,14 @@
#include "Runtime.hpp"
ExecutionSession::ExecutionSession(std::string sharedLibPath,
std::string entryPointName) {
ExecutionSession::ExecutionSession(
std::string sharedLibPath, std::string entryPointName) {
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
_entryPointFunc =
(entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
}
std::vector<py::array>
ExecutionSession::run(std::vector<py::array> inputsPyArray) {
std::vector<py::array> ExecutionSession::run(
std::vector<py::array> inputsPyArray) {
assert(_entryPointFunc && "entry point not loaded");
auto *wrappedInput = createOrderedDynMemRefDict();
int inputIdx = 0;
@ -40,8 +40,8 @@ ExecutionSession::run(std::vector<py::array> inputsPyArray) {
auto *wrappedOutput = _entryPointFunc(wrappedInput);
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
auto *dynMemRef = getDynMemRef(wrappedOutput, i);
auto shape = std::vector<int64_t>(dynMemRef->sizes,
dynMemRef->sizes + dynMemRef->rank);
auto shape = std::vector<int64_t>(
dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
outputPyArrays.emplace_back(
py::array(py::dtype("float32"), shape, dynMemRef->data));
}


@ -144,9 +144,9 @@ public:
assert(krnlGlobalOp.value().hasValue() &&
"Krnl Global must always have a value");
global = rewriter.create<LLVM::GlobalOp>(loc,
llvmGlobalType, /*isConstant=*/true,
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
/*isConstant=*/true, LLVM::Linkage::Internal, name,
krnlGlobalOp.value().getValue());
}
// Some frequently used types.


@ -75,7 +75,7 @@ public:
OwningRewritePatternList patterns;
auto *context = &getContext();
ConstantOp::getCanonicalizationPatterns(patterns, context);
applyPatternsAndFoldGreedily(f, patterns);
applyPatternsAndFoldGreedily(f, patterns);
}
};
} // end anonymous namespace


@ -12,35 +12,35 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include <numeric>
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include <numeric>
using namespace mlir;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXCombine.inc"
} // end anonymous namespace
} // end anonymous namespace
/// Register optimization patterns as "canonicalization" patterns
/// on the ONNXMatMultOp.
void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<MulAddToGemmOptPattern>(context);
}
void ONNXGemmOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<FuseGemmFollowedByAddition>(context);
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FuseGemmFollowedByAddition>(context);
}
/// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<IdentityEliminationPattern>(context);
}
///on the ONNXPadConstantValueOp.
/// on the ONNXPadConstantValueOp.
void ONNXPadConstantValueOp::getCanonicalizationPatterns(
OwningRewritePatternList& result, MLIRContext* context) {
OwningRewritePatternList &result, MLIRContext *context) {
result.insert<ConstantPadPattern>(context);
}


@ -27,7 +27,8 @@ namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc"
struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
struct DecomposeONNXToONNXPass
: public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
void runOnFunction() final;
};
} // end anonymous namespace.


@ -9,10 +9,10 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Pass/Passes.hpp"
@ -25,7 +25,8 @@ namespace {
* candidate operations and propagating the shape information until the list
* of operations is empty [credit MLIR authors].
*/
class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
public:
void runOnFunction() override {
auto f = getFunction();
@ -63,8 +64,7 @@ public:
if (auto terminator_op = f.getBody().back().getTerminator()) {
auto results = terminator_op->getOperandTypes();
f.setType(FunctionType::get(
f.getType().getInputs(),
f.setType(FunctionType::get(f.getType().getInputs(),
std::vector<Type>(results.begin(), results.end()), f.getContext()));
}
}
@ -146,5 +146,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>();
}
static PassRegistration<ShapeInferencePass>
pass("shape-inference", "Shape inference for frontend dialects.");
static PassRegistration<ShapeInferencePass> pass(
"shape-inference", "Shape inference for frontend dialects.");


@ -14,30 +14,30 @@ using namespace onnx_mlir;
int main(int argc, char *argv[]) {
registerDialects();
llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options",
"These are frontend options.");
llvm::cl::opt<string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::OptionCategory OnnxMlirOptions(
"ONNX MLIR Options", "These are frontend options.");
llvm::cl::opt<string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<EmissionTargetType> emissionTarget(
llvm::cl::desc("Choose target to emit:"),
llvm::cl::values(
clEnumVal(EmitONNXBasic,
"Ingest ONNX and emit the basic ONNX operations without"
"inferred shapes."),
clEnumVal(EmitONNXIR,
"Ingest ONNX and emit corresponding ONNX dialect."),
clEnumVal(EmitMLIR,
"Lower model to MLIR built-in transformation dialect."),
"Ingest ONNX and emit the basic ONNX operations without"
"inferred shapes."),
clEnumVal(
EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."),
clEnumVal(
EmitMLIR, "Lower model to MLIR built-in transformation dialect."),
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) "
"LLVM bitcode for model.")),
llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
llvm::cl::ParseCommandLineOptions(argc, argv,
"ONNX MLIR modular optimizer driver\n");
llvm::cl::ParseCommandLineOptions(
argc, argv, "ONNX MLIR modular optimizer driver\n");
mlir::MLIRContext context;
mlir::OwningModuleRef module;