diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..418d021 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,40 @@ +# This is a basic workflow to help you get started with Actions + +name: Clang-Format Bot + +# Controls when the action will run. Triggers the workflow on push or pull request +# events but only for the master branch +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +# A workflow run is made up of one or more jobs that can run sequentially or in parallel +jobs: + # This workflow contains a single job called "build" + build: + # The type of runner that the job will run on + runs-on: ubuntu-latest + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + - name: clang-format lint + uses: DoozyX/clang-format-lint-action@v0.5 + with: + # Source folder to check formatting + source: ./src + # Version of clang-format + clangFormatVersion: 9 # optional, default is 9 + + # Runs a single command using the runners shell + - name: Run a one-line script + run: echo Hello, world! + + # Runs a set of commands using the runners shell + - name: Run a multi-line script + run: | + echo Add other actions to build, + echo test, and deploy your project. diff --git a/src/Builder/FrontendDialectHelper.cpp b/src/Builder/FrontendDialectHelper.cpp index 7cfca8e..49bce5a 100644 --- a/src/Builder/FrontendDialectHelper.cpp +++ b/src/Builder/FrontendDialectHelper.cpp @@ -12,8 +12,8 @@ namespace onnx_mlir { -void replaceAll(std::string &str, const std::string &from, - const std::string &to) { +void replaceAll( + std::string &str, const std::string &from, const std::string &to) { if (from.empty()) return; size_t start_pos = 0; @@ -121,7 +121,6 @@ void InitializedTensorMapping::AddMapping( nameToInitializedTensor.emplace(name, tensor); } - bool InitializedTensorMapping::ContainKey(std::string name) { return nameToInitializedTensor.count(name) != 0; } @@ -132,8 +131,8 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor( onnx::TensorProto initializer = GetInitializedTensor(name); // Tensor dimensions. - llvm::ArrayRef tensorDims(initializer.dims().data(), - initializer.dims().size()); + llvm::ArrayRef tensorDims( + initializer.dims().data(), initializer.dims().size()); // Emit ConstantOp and record the mapping between the input and // the constant value. 
@@ -142,33 +141,32 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor( mlir::Type elementType; mlir::ShapedType tensorType; switch (initializer.data_type()) { - case (onnx::TensorProto::FLOAT): { - const auto& arrayAttrInitializer = - CreateArrayAttribute(initializer); - elementType = builder.getF32Type(); - tensorType = mlir::RankedTensorType::get(tensorDims, elementType); - constantDenseAttribute = mlir::DenseElementsAttr::get( - tensorType, llvm::makeArrayRef(arrayAttrInitializer)); - break; - } - case (onnx::TensorProto::INT32): { - const auto& arrayAttrInitializer = - CreateArrayAttribute(initializer); - elementType = builder.getIntegerType(32); - tensorType = mlir::RankedTensorType::get(tensorDims, elementType); - constantDenseAttribute = mlir::DenseElementsAttr::get( - tensorType, llvm::makeArrayRef(arrayAttrInitializer)); - break; - } - case (onnx::TensorProto::INT64): { - const auto& arrayAttrInitializer = - CreateArrayAttribute(initializer); - elementType = builder.getIntegerType(64); - tensorType = mlir::RankedTensorType::get(tensorDims, elementType); - constantDenseAttribute = mlir::DenseElementsAttr::get( - tensorType, llvm::makeArrayRef(arrayAttrInitializer)); - break; - } + case (onnx::TensorProto::FLOAT): { + const auto &arrayAttrInitializer = CreateArrayAttribute(initializer); + elementType = builder.getF32Type(); + tensorType = mlir::RankedTensorType::get(tensorDims, elementType); + constantDenseAttribute = mlir::DenseElementsAttr::get( + tensorType, llvm::makeArrayRef(arrayAttrInitializer)); + break; + } + case (onnx::TensorProto::INT32): { + const auto &arrayAttrInitializer = + CreateArrayAttribute(initializer); + elementType = builder.getIntegerType(32); + tensorType = mlir::RankedTensorType::get(tensorDims, elementType); + constantDenseAttribute = mlir::DenseElementsAttr::get( + tensorType, llvm::makeArrayRef(arrayAttrInitializer)); + break; + } + case (onnx::TensorProto::INT64): { + const auto &arrayAttrInitializer = + CreateArrayAttribute(initializer); + elementType = builder.getIntegerType(64); + tensorType = mlir::RankedTensorType::get(tensorDims, elementType); + constantDenseAttribute = mlir::DenseElementsAttr::get( + tensorType, llvm::makeArrayRef(arrayAttrInitializer)); + break; + } } // Create ConstantOp for dense array. diff --git a/src/Builder/FrontendDialectHelper.hpp b/src/Builder/FrontendDialectHelper.hpp index 82d0c21..2f329da 100644 --- a/src/Builder/FrontendDialectHelper.hpp +++ b/src/Builder/FrontendDialectHelper.hpp @@ -20,8 +20,8 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Location.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" @@ -37,11 +37,12 @@ #endif #include "onnx/onnx_pb.h" +#include "src/Dialect/ONNX/ONNXOps.hpp" namespace onnx_mlir { -void replaceAll(std::string &str, const std::string &from, - const std::string &to); +void replaceAll( + std::string &str, const std::string &from, const std::string &to); std::string legalize_name(std::string name); @@ -86,14 +87,14 @@ struct InitializedTensorMapping { // This will allow the propagation of shape information passed in as an // argument to operations such as Reshape and will enable other // optimizations such as constant folding. 
- mlir::Value EmitInitializerForInputTensor(mlir::Location loc, - mlir::OpBuilder &builder, std::string name); + mlir::Value EmitInitializerForInputTensor( + mlir::Location loc, mlir::OpBuilder &builder, std::string name); // Get initialized tensor. - onnx::TensorProto& GetInitializedTensor(std::string name) { - assert(nameToInitializedTensor.find(name) != - nameToInitializedTensor.end() && - "Tensor initializer not found"); + onnx::TensorProto &GetInitializedTensor(std::string name) { + assert( + nameToInitializedTensor.find(name) != nameToInitializedTensor.end() && + "Tensor initializer not found"); return nameToInitializedTensor.at(name); } diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 22a8b57..c7af1ea 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -127,8 +127,8 @@ private: * @param input onnx input tensor ValueInfoProto. * @param symbol mlir input argument. */ - void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, - mlir::Value symbol) { + void ImportInputTensorSymbol( + const onnx::ValueInfoProto &input, mlir::Value symbol) { auto input_tensor_legalized_name = legalize_name(input.name()); assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) && "Found duplicate legalized input tensor names."); @@ -136,8 +136,7 @@ private: } typedef bstd::variant, float, - std::vector, std::string, - std::vector> + std::vector, std::string, std::vector> AttrValueType; struct ONNXAttrVisitor { @@ -213,8 +212,8 @@ private: llvm_unreachable("Failed to convert attribute proto to name/value pair"); } - std::vector - ImportNodeAttributes(const onnx::NodeProto &node) { + std::vector ImportNodeAttributes( + const onnx::NodeProto &node) { std::vector attributes; for (int i = 0; i < node.attribute_size(); ++i) { auto attr = node.attribute(i); @@ -281,26 +280,25 @@ private: // TODO: Handle optional inputs. auto op = builder_.create(UnknownLoc(), outputTypes, inputs, attributes); for (int i = 0; i < node.output().size(); i++) { - frontend_symbols_.AddMapping(legalize_name(node.output()[i]), - *(op.getODSResults(i).begin())); + frontend_symbols_.AddMapping( + legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); } } template - void buildOperation(const onnx::NodeProto &node, - int expectedNumOperands = -1, - int expectedNumResults = -1) { + void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1, + int expectedNumResults = -1) { std::vector inputs; for (const auto &item : node.input()) if (initializedTensors.ContainKey(legalize_name(item))) { inputs.push_back(initializedTensors.EmitInitializerForInputTensor( - UnknownLoc(), builder_, legalize_name(item))); + UnknownLoc(), builder_, legalize_name(item))); } else if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } - buildOutputAndOperation(node, inputs, expectedNumOperands, - expectedNumResults); + buildOutputAndOperation( + node, inputs, expectedNumOperands, expectedNumResults); } void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) { @@ -310,9 +308,8 @@ private: item = node.input()[i]; // For the second argument, check if there exists an initializer. 
if (initializedTensors.ContainKey(legalize_name(item))) { - inputs.push_back( - initializedTensors.EmitInitializerForInputTensor( - UnknownLoc(), builder_, legalize_name(item))); + inputs.push_back(initializedTensors.EmitInitializerForInputTensor( + UnknownLoc(), builder_, legalize_name(item))); } else if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -372,7 +369,6 @@ private: #if INCLUDE_ONNX_ML == 1 #include "src/Builder/MLOpBuildTable.inc" #endif - } /*! @@ -388,8 +384,8 @@ private: * output tensor. */ void ImportOutputTensor(const onnx::ValueInfoProto &output, - llvm::SmallVectorImpl &ret_types, - llvm::SmallVectorImpl &ret_vals) { + llvm::SmallVectorImpl &ret_types, + llvm::SmallVectorImpl &ret_vals) { auto output_tensor_legalized_name = legalize_name(output.name()); assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) && "Output tensor not found"); @@ -400,8 +396,8 @@ private: ret_vals.push_back(tensor_val); } - void ImportGraph(const onnx::GraphProto &graph, - const std::string &name = "main_graph") { + void ImportGraph( + const onnx::GraphProto &graph, const std::string &name = "main_graph") { // Maintain a mapping between the parameter and its initializer. for (auto initializer : graph.initializer()) { auto name = initializer.name(); @@ -426,8 +422,7 @@ private: // Emit the entry point operation which specifies the number of user // inputs and outputs. - auto entryPoint = mlir::ONNXEntryPointOp::create( - UnknownLoc(), mainFunc, + auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size() - graph.initializer().size(), /*numOutputs=*/graph.output().size()); @@ -454,8 +449,8 @@ private: // Create a NoneTyped constant to be used for optional operation inputs // which are not used. - none_ = builder_.create(UnknownLoc(), - builder_.getUnitAttr()); + none_ = + builder_.create(UnknownLoc(), builder_.getUnitAttr()); // Import nodes in the graph. for (const auto &item : graph.node()) { @@ -483,8 +478,7 @@ private: namespace onnx_mlir { void ImportFrontendModelFile(std::string model_fname, - mlir::MLIRContext &context, - mlir::OwningModuleRef &module) { + mlir::MLIRContext &context, mlir::OwningModuleRef &module) { onnx::ModelProto model; std::fstream input(model_fname, std::ios::in | std::ios::binary); diff --git a/src/Builder/FrontendDialectTransformer.hpp b/src/Builder/FrontendDialectTransformer.hpp index fe39675..04eace8 100644 --- a/src/Builder/FrontendDialectTransformer.hpp +++ b/src/Builder/FrontendDialectTransformer.hpp @@ -36,11 +36,10 @@ namespace onnx_mlir { * @return MLIR::module generated for the ONNX model. */ void ImportFrontendModelFile(std::string model_fname, - mlir::MLIRContext &context, - mlir::OwningModuleRef &module); + mlir::MLIRContext &context, mlir::OwningModuleRef &module); /*! * TODO: Import models into other extension dialects that cover the * operations specific to other frameworks such as Tensorflow or Pytorch. 
*/ -} // namespace onnx_mlir +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index f529a1a..f2e04b7 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -1,4 +1,5 @@ -//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering --------===// +//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering +//--------===// // // Copyright 2019 The IBM Research Authors. // @@ -21,10 +22,9 @@ class ONNXEntryPointLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ONNXEntryPointOp op, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, + LogicalResult matchAndRewrite( + ONNXEntryPointOp op, PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getAttrOfType( ONNXEntryPointOp::getEntryPointFuncAttrName()), op.getAttrOfType(ONNXEntryPointOp::getNumInputsAttrName()), @@ -55,8 +55,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() { // We define the specific operations, or dialects, that are legal targets for // this lowering. - target - .addLegalDialect(); + target.addLegalDialect(); // TODO: enable this once more ops are supported. // We also define the ONNX dialect as Illegal so that the conversion will fail @@ -81,8 +80,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() { // Type conversion for function signatures. // Call MLIR FuncOp signature conversion when result type is // a ranked tensor. - populateFuncOpTypeConversionPattern(patterns, &getContext(), - tensor_to_memref_converter); + populateFuncOpTypeConversionPattern( + patterns, &getContext(), tensor_to_memref_converter); // Frontend operation lowering. // Math @@ -119,5 +118,5 @@ std::unique_ptr mlir::createLowerToKrnlPass() { return std::make_unique(); } -static PassRegistration - pass("lower-frontend", "Lower frontend ops to Krnl dialect."); +static PassRegistration pass( + "lower-frontend", "Lower frontend ops to Krnl dialect."); diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 29460ca..7800349 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -499,9 +499,8 @@ template struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { // TODO: Check that the types are valid. // An element-wise unary operation must have all operands and the result of // the same type. This should have been verified by the verifier. @@ -566,9 +565,8 @@ template struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { // TODO: Check that the types are valid. 
// An element-wise variadic operation must have all operands and the result // of the same type. This should have been verified by the verifier. diff --git a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp index e17e955..1b6bd58 100644 --- a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp @@ -1,4 +1,5 @@ -//===----------------- Gemm.cpp - Lowering Gemm Op -------------------------===// +//===----------------- Gemm.cpp - Lowering Gemm Op +//-------------------------===// // // Copyright 2019 The IBM Research Authors. // @@ -17,9 +18,8 @@ struct ONNXGemmOpLowering : public ConversionPattern { ONNXGemmOpLowering(MLIRContext *ctx) : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); bool hasBias = !op->getOperand(2).getType().isa(); @@ -32,12 +32,10 @@ struct ONNXGemmOpLowering : public ConversionPattern { auto memRefType = convertToMemRefType(*op->result_type_begin()); - auto alphaAttr = - FloatAttr::get(memRefType.getElementType(), - llvm::dyn_cast(op).alpha().convertToFloat()); - auto betaAttr = - FloatAttr::get(memRefType.getElementType(), - llvm::dyn_cast(op).beta().convertToFloat()); + auto alphaAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op).beta().convertToFloat()); auto alpha = rewriter.create(loc, alphaAttr); auto beta = rewriter.create(loc, betaAttr); @@ -101,8 +99,8 @@ struct ONNXGemmOpLowering : public ConversionPattern { optimizedReductionLoops.reserve(1); reductionLoops.push_back(originalLoops[2]); optimizedReductionLoops.push_back(optimizedLoops[2]); - KrnlIterateOperandPack reductionPack(rewriter, reductionLoops, - optimizedReductionLoops); + KrnlIterateOperandPack reductionPack( + rewriter, reductionLoops, optimizedReductionLoops); // Induction variable for the reduction dimension // Try to find and use a static value from A or B first. // If it failed then use a dynamic value. 
@@ -167,8 +165,8 @@ struct ONNXGemmOpLowering : public ConversionPattern { auto loadedAB = rewriter.create(loc, alloc, loopMNIVs); auto alphaAB = rewriter.create(loc, alpha, loadedAB); if (hasBias) { - auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, - broadcastedDimInfo); + auto loopCIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopMNIVs, C, broadcastedDimInfo); auto loadedC = rewriter.create(loc, C, loopCIVs); auto betaC = rewriter.create(loc, beta, loadedC); auto Y = rewriter.create(loc, alphaAB, betaC); @@ -214,7 +212,7 @@ struct ONNXGemmOpLowering : public ConversionPattern { } }; -void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns, - MLIRContext *ctx) { +void populateLoweringONNXGemmOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert>(ctx); } diff --git a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp index 39428d7..4a639e7 100644 --- a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp +++ b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp @@ -16,9 +16,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern { ONNXMatMulOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); ONNXMatMulOpOperandAdaptor operandAdaptor(operands); @@ -119,8 +118,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern { // Define loops for batch dimensions. std::vector originalLoops; std::vector optimizedLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, - optimizedLoops, memRefShape.size()); + Block *optimizationBlock = defineLoops( + rewriter, loc, originalLoops, optimizedLoops, memRefShape.size()); // Outer KrnlIterateOp SmallVector loopBatchIVs; @@ -139,8 +138,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern { outerLoops.push_back(originalLoops[i]); optimizedOuterLoops.push_back(optimizedLoops[i]); } - KrnlIterateOperandPack outerPack(rewriter, outerLoops, - optimizedOuterLoops); + KrnlIterateOperandPack outerPack( + rewriter, outerLoops, optimizedOuterLoops); for (int i = 0; i < batchAxes.size(); ++i) { addDimensionToPack(rewriter, loc, outerPack, alloc, i); } @@ -176,11 +175,11 @@ struct ONNXMatMulOpLowering : public ConversionPattern { optimizedMatmulLoops.emplace_back( optimizedLoops[memRefShape.size() - i]); } - KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, - optimizedMatmulLoops); + KrnlIterateOperandPack matmulPack( + rewriter, matmulLoops, optimizedMatmulLoops); for (int i = 2; i > 0; --i) { - addDimensionToPack(rewriter, loc, matmulPack, alloc, - memRefShape.size() - i); + addDimensionToPack( + rewriter, loc, matmulPack, alloc, memRefShape.size() - i); } matmulIterateOp = rewriter.create(loc, matmulPack); } else { @@ -190,10 +189,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern { matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]); optimizedMatmulLoops.emplace_back( optimizedLoops[memRefShape.size() - 1]); - KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, - optimizedMatmulLoops); - addDimensionToPack(rewriter, loc, matmulPack, alloc, - memRefShape.size() - 1); + KrnlIterateOperandPack matmulPack( + rewriter, matmulLoops, optimizedMatmulLoops); + addDimensionToPack( + rewriter, loc, matmulPack, 
alloc, memRefShape.size() - 1); matmulIterateOp = rewriter.create(loc, matmulPack); } @@ -230,8 +229,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern { std::vector optimizedReduceLoops; Block *optimizationReduceBlock = defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); - KrnlIterateOperandPack reducePack(rewriter, reduceLoops, - optimizedReduceLoops); + KrnlIterateOperandPack reducePack( + rewriter, reduceLoops, optimizedReduceLoops); addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1); auto reduceIterateOp = rewriter.create(loc, reducePack); @@ -292,8 +291,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern { std::vector optimizedReduceLoops; Block *optimizationReduceBlock = defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); - KrnlIterateOperandPack reducePack(rewriter, reduceLoops, - optimizedReduceLoops); + KrnlIterateOperandPack reducePack( + rewriter, reduceLoops, optimizedReduceLoops); addDimensionToPack(rewriter, loc, reducePack, A, 0); auto reduceIterateOp = rewriter.create(loc, reducePack); diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp index 3520827..a0fe36b 100644 --- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp @@ -102,9 +102,8 @@ struct ONNXReductionOpLowering : public ConversionPattern { ONNXReductionOpLowering(MLIRContext *ctx) : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { /* * Condition: reduction function must be associative and commutative. * diff --git a/src/Conversion/ONNXToKrnl/NN/Conv.cpp b/src/Conversion/ONNXToKrnl/NN/Conv.cpp index aa1e076..51d17c5 100644 --- a/src/Conversion/ONNXToKrnl/NN/Conv.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Conv.cpp @@ -1,4 +1,5 @@ -//===--------------- Conv.cpp - Lowering Convolution Op --------------------===// +//===--------------- Conv.cpp - Lowering Convolution Op +//--------------------===// // // Copyright 2019 The IBM Research Authors. // @@ -175,14 +176,12 @@ struct ONNXConvOpLowering : public ConversionPattern { // Emit the bias, if needed. if (hasBias) { - auto loadResult = - rewriter.create(loc, alloc, resultIndices); + auto loadResult = rewriter.create(loc, alloc, resultIndices); SmallVector biasIndices; biasIndices.emplace_back(kernel); - auto loadBias = - rewriter.create(loc, biasOperand, kernel); - auto resultWithBias = rewriter.create( - loc, loadResult, loadBias); + auto loadBias = rewriter.create(loc, biasOperand, kernel); + auto resultWithBias = + rewriter.create(loc, loadResult, loadBias); // Store initializer value into output location. 
rewriter.create(loc, resultWithBias, alloc, resultIndices); } diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index a276de8..3113d0d 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -459,7 +459,8 @@ struct ONNXPoolOpLowering : public ConversionPattern { poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext()); // poolStartMap and poolEndMap - poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext()); + poolStartMap = + AffineMap::get(1, 5, {start1, start2}, rewriter.getContext()); poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext()); } diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index aac7dc9..c2f5ef3 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -36,9 +36,7 @@ MemRefType convertToMemRefType(Type type) { /// Insert an allocation and deallocation for the given MemRefType. Value insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter, - bool insertDealloc, - ArrayRef operands) { + PatternRewriter &rewriter, bool insertDealloc, ArrayRef operands) { // Put together alloc operands for any dynamic dimensions of the memref. AllocOp alloc; if (!operands.empty()) { @@ -64,10 +62,10 @@ Value insertAllocAndDealloc(MemRefType type, Location loc, auto operandDim = rewriter.create(loc, operands[i], operandDimIdx); if (maxDim) { - auto maxCondition = rewriter.create(loc, CmpIPredicate::sgt, - operandDim, maxDim); - maxDim = rewriter.create(loc, maxCondition, operandDim, - maxDim); + auto maxCondition = rewriter.create( + loc, CmpIPredicate::sgt, operandDim, maxDim); + maxDim = rewriter.create( + loc, maxCondition, operandDim, maxDim); } else { maxDim = operandDim; } @@ -122,8 +120,8 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) { // Create a mapping from result type's dimensions to input type's dimensions, // given that the result type is the result of a reduction op over the input // type. -std::map -getReductionMapping(MemRefType inputTy, ArrayRef axes, bool keepdims) { +std::map getReductionMapping( + MemRefType inputTy, ArrayRef axes, bool keepdims) { std::map OutInDimMap; int64_t rank = inputTy.getRank(); @@ -152,9 +150,8 @@ getReductionMapping(MemRefType inputTy, ArrayRef axes, bool keepdims) { // Add bounds associated with the op operand to the KRNL iteration pack. // Dynamic dimenions are supported. -void addDimensionToPack(ConversionPatternRewriter &rewriter, - Location loc, KrnlIterateOperandPack &pack, - Value operand, int index) { +void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc, + KrnlIterateOperandPack &pack, Value operand, int index) { auto shape = operand.getType().cast().getShape(); if (shape[index] < 0) { pack.pushConstantBound(0); @@ -168,10 +165,9 @@ void addDimensionToPack(ConversionPatternRewriter &rewriter, // Function that defines the KRNL dialect loops and their respective // optimized version. -KrnlOptimizeLoopsOp -emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc, - std::vector &loops, - std::vector &optimizedLoops, int64_t numLoops) { +KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter, + Location loc, std::vector &loops, std::vector &optimizedLoops, + int64_t numLoops) { // Define loops. 
auto loopsOp = rewriter.create(loc, numLoops); loops.reserve(numLoops); @@ -190,9 +186,8 @@ emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc, // Function that emits the loops and their optimized version. // The function returns a reference to the inner optimization block. Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc, - std::vector &loops, - std::vector &optimizedLoops, - int64_t numLoops) { + std::vector &loops, std::vector &optimizedLoops, + int64_t numLoops) { KrnlOptimizeLoopsOp optimizedLoopsOp = emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops); return &optimizedLoopsOp.region().front(); @@ -201,10 +196,9 @@ Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc, // Function which emits a basic set of loops and optimized loops // for a given operation argument. A reference to the loop optimization // block is returned in the last argument of the function. -void emitKrnlLoopsAndIterationForOperand( - ConversionPatternRewriter &rewriter, Location loc, Value operand, - std::vector &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp, - KrnlIterateOp &iterateOp) { +void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter, + Location loc, Value operand, std::vector &originalLoops, + KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) { // Operand shape. auto shape = operand.getType().cast().getShape(); @@ -240,9 +234,9 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { // Get run-time dimension information for unknown dimensions used for // broadcasting. -std::map> -getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, - MemRefType memRefType, ArrayRef operands) { +std::map> getBroadcastedDimInfo(Location loc, + ConversionPatternRewriter &rewriter, MemRefType memRefType, + ArrayRef operands) { auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // For unknown dimensions, we need to get dimension values at runtime in @@ -286,10 +280,9 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, // Extract induction variables that are used for broadcasting values of a // given operand. -std::vector -getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, - ArrayRef loopIVs, Value operand, - std::map broadcastedDims) { +std::vector getLoopIVsForBroadcasting(Location loc, + ConversionPatternRewriter &rewriter, ArrayRef loopIVs, Value operand, + std::map broadcastedDims) { // `operand` must has a ranked type. This should have been checked by the // shape inference pass. auto operandShape = operand.getType().cast().getShape(); @@ -310,8 +303,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, // If its value is 1, it is broadcasted dimension. // Otherwise, non-broadcasted dimension. 
auto zero = rewriter.create(loc, 0); - auto idx = rewriter.create(loc, broadcastedDims[dimIdx], zero, - loopIVs[loopIdx]); + auto idx = rewriter.create( + loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]); newLoopIVs.insert(newLoopIVs.begin(), idx); } else { // Non-broadcasted dimension diff --git a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp index 4511fc0..85f3b82 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp @@ -30,7 +30,7 @@ struct ONNXConcatOpLowering : public ConversionPattern { auto memRefType = convertToMemRefType(*op->result_type_begin()); auto resultShape = memRefType.getShape(); auto rank = resultShape.size(); - assert((axis >=0 && axis < rank) && "Concat axis out of bounds"); + assert((axis >= 0 && axis < rank) && "Concat axis out of bounds"); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index e5105d1..709bd56 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -17,8 +17,8 @@ struct ONNXConstantOpLowering : public ConversionPattern { ONNXConstantOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) { - constantID = 0; - } + constantID = 0; + } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -34,12 +34,11 @@ struct ONNXConstantOpLowering : public ConversionPattern { // Shape based computations. auto shape = memRefType.getShape(); int64_t numElements = 1; - for (int i=0; i(loc, - memRefType, + auto constantGlobal = rewriter.create(loc, memRefType, rewriter.getI64ArrayAttr(shape), rewriter.getStringAttr("constant_" + std::to_string(constantID)), constantOp.value().getValue()); diff --git a/src/Dialect/Krnl/KrnlTypes.hpp b/src/Dialect/Krnl/KrnlTypes.hpp index 159195f..db75855 100644 --- a/src/Dialect/Krnl/KrnlTypes.hpp +++ b/src/Dialect/Krnl/KrnlTypes.hpp @@ -24,15 +24,15 @@ enum Kinds { } class LoopType : public mlir::Type::TypeBase { - public: +public: using Base::Base; // Support type inquiry through isa, cast and dyn_cast. static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; } // Get a unique instance of Loop type. 
- static LoopType get(mlir::MLIRContext* context) { + static LoopType get(mlir::MLIRContext *context) { return Base::get(context, KrnlTypes::Loop); } }; -} // namespace mlir +} // namespace mlir diff --git a/src/Dialect/MLONNX/MLONNXOps.cpp b/src/Dialect/MLONNX/MLONNXOps.cpp index 02d5ef1..0b5c6ec 100644 --- a/src/Dialect/MLONNX/MLONNXOps.cpp +++ b/src/Dialect/MLONNX/MLONNXOps.cpp @@ -39,7 +39,6 @@ MLONNXOpsDialect::MLONNXOpsDialect(mlir::MLIRContext *ctx) >(); } - //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/Dialect/MLONNX/MLONNXOps.hpp b/src/Dialect/MLONNX/MLONNXOps.hpp index 380a86a..3859756 100644 --- a/src/Dialect/MLONNX/MLONNXOps.hpp +++ b/src/Dialect/MLONNX/MLONNXOps.hpp @@ -19,14 +19,14 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" -#include "src/Interface/ShapeInferenceInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp" +#include "src/Interface/ShapeInferenceInterface.hpp" namespace mlir { class MLONNXOpsDialect : public Dialect { - public: - MLONNXOpsDialect(MLIRContext* context); +public: + MLONNXOpsDialect(MLIRContext *context); /// Provide a utility accessor to the dialect namespace. This is used by /// several utilities for casting between dialects. @@ -38,6 +38,6 @@ class MLONNXOpsDialect : public Dialect { #define GET_OP_CLASSES #include "src/Dialect/MLONNX/MLONNXOps.hpp.inc" -} // end namespace mlir +} // end namespace mlir namespace onnx_mlir {} diff --git a/src/Dialect/ONNX/ONNXOps.hpp b/src/Dialect/ONNX/ONNXOps.hpp index ab19f26..66c0e7e 100644 --- a/src/Dialect/ONNX/ONNXOps.hpp +++ b/src/Dialect/ONNX/ONNXOps.hpp @@ -19,14 +19,14 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" -#include "src/Interface/ShapeInferenceInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp" +#include "src/Interface/ShapeInferenceInterface.hpp" namespace mlir { class ONNXOpsDialect : public Dialect { - public: - ONNXOpsDialect(MLIRContext* context); +public: + ONNXOpsDialect(MLIRContext *context); /// Provide a utility accessor to the dialect namespace. This is used by /// several utilities for casting between dialects. @@ -38,6 +38,6 @@ class ONNXOpsDialect : public Dialect { #define GET_OP_CLASSES #include "src/Dialect/ONNX/ONNXOps.hpp.inc" -} // end namespace mlir +} // end namespace mlir namespace onnx_mlir {} diff --git a/src/Interface/PromotableConstOperandsOpInterface.cpp b/src/Interface/PromotableConstOperandsOpInterface.cpp index 92e20e2..03d4df6 100644 --- a/src/Interface/PromotableConstOperandsOpInterface.cpp +++ b/src/Interface/PromotableConstOperandsOpInterface.cpp @@ -10,7 +10,6 @@ // //===----------------------------------------------------------------------===// - #include "src/Interface/PromotableConstOperandsOpInterface.hpp" using namespace mlir; @@ -20,4 +19,3 @@ using namespace mlir; //===----------------------------------------------------------------------===// #include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc" - diff --git a/src/Interface/PromotableConstOperandsOpInterface.hpp b/src/Interface/PromotableConstOperandsOpInterface.hpp index a105fcb..d8ba72e 100644 --- a/src/Interface/PromotableConstOperandsOpInterface.hpp +++ b/src/Interface/PromotableConstOperandsOpInterface.hpp @@ -22,4 +22,4 @@ namespace mlir { /// Include the auto-generated declarations. 
#include "src/Interface/PromotableConstOperandsOpInterface.hpp.inc" -} // end namespace mlir \ No newline at end of file +} // end namespace mlir \ No newline at end of file diff --git a/src/Interface/ShapeInferenceInterface.cpp b/src/Interface/ShapeInferenceInterface.cpp index 3b17e1f..eb5fe2e 100644 --- a/src/Interface/ShapeInferenceInterface.cpp +++ b/src/Interface/ShapeInferenceInterface.cpp @@ -16,4 +16,4 @@ namespace mlir { /// Include the auto-generated declarations. #include "src/Interface/ShapeInference.cpp.inc" -} // end namespace mlir +} // end namespace mlir diff --git a/src/Interface/ShapeInferenceInterface.hpp b/src/Interface/ShapeInferenceInterface.hpp index 812a7b3..d1badc2 100644 --- a/src/Interface/ShapeInferenceInterface.hpp +++ b/src/Interface/ShapeInferenceInterface.hpp @@ -18,4 +18,4 @@ namespace mlir { /// Include the auto-generated declarations. #include "src/Interface/ShapeInference.hpp.inc" -} // end namespace mlir +} // end namespace mlir diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index ecf5495..c806e4e 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -14,7 +14,7 @@ #ifdef _WIN32 #include -#else +#else #include #endif @@ -22,7 +22,7 @@ using namespace std; using namespace onnx_mlir; void LoadMLIR(string inputFilename, mlir::MLIRContext &context, - mlir::OwningModuleRef &module) { + mlir::OwningModuleRef &module) { // Handle '.mlir' input to the ONNX MLIR frontend. // The mlir format indicates that one or more of the supported // representations are used in the file. @@ -46,10 +46,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context, void EmitLLVMBitCode( const mlir::OwningModuleRef &module, string outputFilename) { error_code error; - llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error, - llvm::sys::fs::F_None); - llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), - moduleBitcodeStream); + llvm::raw_fd_ostream moduleBitcodeStream( + outputFilename, error, llvm::sys::fs::F_None); + llvm::WriteBitcodeToFile( + *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); moduleBitcodeStream.flush(); } @@ -90,7 +90,7 @@ void addKrnlToLLVMPasses(mlir::PassManager &pm) { } void processInputFile(string inputFilename, EmissionTargetType emissionTarget, - mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + mlir::MLIRContext &context, mlir::OwningModuleRef &module) { // Decide if the input file is an ONNX model or a model specified // in MLIR. The extension of the file is the decider. string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); @@ -99,7 +99,6 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget, assert(inputIsONNX != inputIsMLIR && "Either ONNX model or MLIR file needs to be provided."); - if (inputIsONNX) { ImportFrontendModelFile(inputFilename, context, module); } else { @@ -119,8 +118,8 @@ void outputCode( module->dump(); fflush(stderr); // set modified stderr as original stderr - _dup2(stderrOrigin, _fileno( stderr )); -#else + _dup2(stderrOrigin, _fileno(stderr)); +#else if (fork() == 0) { freopen(tempFilename.c_str(), "w", stderr); module->dump(); @@ -151,7 +150,7 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget, // necessary when emitting the .bc file. if (emissionTarget == EmitLLVMBC) { // Write LLVM bitcode to disk. 
- string outputFilename = outputBaseName + ".bc"; + string outputFilename = outputBaseName + ".bc"; EmitLLVMBitCode(module, outputFilename); printf("LLVM bitcode written to %s\n", outputFilename.c_str()); } else { diff --git a/src/MainUtils.hpp b/src/MainUtils.hpp index bf96cea..6eb8802 100644 --- a/src/MainUtils.hpp +++ b/src/MainUtils.hpp @@ -28,9 +28,9 @@ #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/InitAllDialects.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" +#include "mlir/InitAllDialects.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -46,10 +46,10 @@ enum EmissionTargetType { }; void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context, - mlir::OwningModuleRef &module); + mlir::OwningModuleRef &module); void EmitLLVMBitCode( - const mlir::OwningModuleRef &module, std::string outputFilename); + const mlir::OwningModuleRef &module, std::string outputFilename); void registerDialects(); @@ -66,8 +66,7 @@ void processInputFile(std::string inputFilename, mlir::OwningModuleRef &module); void outputCode( - mlir::OwningModuleRef &module, std::string filename, - std::string extension); + mlir::OwningModuleRef &module, std::string filename, std::string extension); void emitOutputFiles(std::string outputBaseName, EmissionTargetType emissionTarget, mlir::MLIRContext &context, diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 4048932..8ccb537 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -38,4 +38,4 @@ std::unique_ptr createElideConstGlobalValuePass(); /// Pass for lowering Krnl dialect to LLVM dialect. std::unique_ptr createKrnlLowerToLLVMPass(); -} // end namespace mlir +} // end namespace mlir diff --git a/src/Runtime/DataType.h b/src/Runtime/DataType.h index 9631318..c9cc174 100644 --- a/src/Runtime/DataType.h +++ b/src/Runtime/DataType.h @@ -1,15 +1,15 @@ enum DYN_MEMREF_DATA_TYPE { UNDEFINED = 0; // Basic types. - FLOAT = 1; // float - UINT8 = 2; // uint8_t - INT8 = 3; // int8_t - UINT16 = 4; // uint16_t - INT16 = 5; // int16_t - INT32 = 6; // int32_t - INT64 = 7; // int64_t - STRING = 8; // string - BOOL = 9; // bool + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool // IEEE754 half-precision floating-point format (16 bits wide). // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. @@ -18,8 +18,8 @@ enum DYN_MEMREF_DATA_TYPE { DOUBLE = 11; UINT32 = 12; UINT64 = 13; - COMPLEX64 = 14; // complex with float32 real and imaginary components - COMPLEX128 = 15; // complex with float64 real and imaginary components + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components // Non-IEEE floating-point format based on IEEE754 single-precision // floating-point number truncated to 16 bits. 
diff --git a/src/Runtime/DynMemRef.cpp b/src/Runtime/DynMemRef.cpp index da299d1..b6c9b8c 100644 --- a/src/Runtime/DynMemRef.cpp +++ b/src/Runtime/DynMemRef.cpp @@ -33,8 +33,8 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) { return tensorDict->tensorDict[tensorDict->orderedNames[idx]]; } -void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, - DynMemRef *tensor) { +void setDynMemRef( + OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) { if (tensorDict->orderedNames.size() <= idx) tensorDict->orderedNames.resize(idx + 1); diff --git a/src/Runtime/DynMemRef.h b/src/Runtime/DynMemRef.h index 6d5bef6..0ccae32 100644 --- a/src/Runtime/DynMemRef.h +++ b/src/Runtime/DynMemRef.h @@ -38,7 +38,6 @@ typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict; #ifdef __cplusplus extern "C" { #endif - // Get number of dynamic memrefs in OrderedDynMemRefDict dict. int numDynMemRefs(OrderedDynMemRefDict *dict); @@ -53,8 +52,8 @@ DynMemRef *createDynMemRef(int rank); DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i); // Set the i-th dynmemref in orderedDict to be dynMemRef. -void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, - DynMemRef *dynMemRef); +void setDynMemRef( + OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef); // Get data pointer from dynMemRef. void *getData(DynMemRef *dynMemRef); diff --git a/src/Runtime/Runtime.cpp b/src/Runtime/Runtime.cpp index 3f177a9..1d21071 100644 --- a/src/Runtime/Runtime.cpp +++ b/src/Runtime/Runtime.cpp @@ -1,14 +1,14 @@ #include "Runtime.hpp" -ExecutionSession::ExecutionSession(std::string sharedLibPath, - std::string entryPointName) { +ExecutionSession::ExecutionSession( + std::string sharedLibPath, std::string entryPointName) { _sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY); _entryPointFunc = (entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str()); } -std::vector -ExecutionSession::run(std::vector inputsPyArray) { +std::vector ExecutionSession::run( + std::vector inputsPyArray) { assert(_entryPointFunc && "entry point not loaded"); auto *wrappedInput = createOrderedDynMemRefDict(); int inputIdx = 0; @@ -40,8 +40,8 @@ ExecutionSession::run(std::vector inputsPyArray) { auto *wrappedOutput = _entryPointFunc(wrappedInput); for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) { auto *dynMemRef = getDynMemRef(wrappedOutput, i); - auto shape = std::vector(dynMemRef->sizes, - dynMemRef->sizes + dynMemRef->rank); + auto shape = std::vector( + dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank); outputPyArrays.emplace_back( py::array(py::dtype("float32"), shape, dynMemRef->data)); } diff --git a/src/Transform/LowerToLLVM.cpp b/src/Transform/LowerToLLVM.cpp index 666955f..1a79a8c 100644 --- a/src/Transform/LowerToLLVM.cpp +++ b/src/Transform/LowerToLLVM.cpp @@ -144,9 +144,9 @@ public: assert(krnlGlobalOp.value().hasValue() && "Krnl Global must always have a value"); - global = rewriter.create(loc, - llvmGlobalType, /*isConstant=*/true, - LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue()); + global = rewriter.create(loc, llvmGlobalType, + /*isConstant=*/true, LLVM::Linkage::Internal, name, + krnlGlobalOp.value().getValue()); } // Some frequently used types. 
diff --git a/src/Transform/ONNX/AttributePromotion.cpp b/src/Transform/ONNX/AttributePromotion.cpp index 8420740..35721cd 100644 --- a/src/Transform/ONNX/AttributePromotion.cpp +++ b/src/Transform/ONNX/AttributePromotion.cpp @@ -75,7 +75,7 @@ public: OwningRewritePatternList patterns; auto *context = &getContext(); ConstantOp::getCanonicalizationPatterns(patterns, context); - applyPatternsAndFoldGreedily(f, patterns); + applyPatternsAndFoldGreedily(f, patterns); } }; } // end anonymous namespace diff --git a/src/Transform/ONNX/ONNXCombine.cpp b/src/Transform/ONNX/ONNXCombine.cpp index b549f0d..60696e0 100644 --- a/src/Transform/ONNX/ONNXCombine.cpp +++ b/src/Transform/ONNX/ONNXCombine.cpp @@ -12,35 +12,35 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include #include "src/Dialect/ONNX/ONNXOps.hpp" +#include using namespace mlir; namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Transform/ONNX/ONNXCombine.inc" -} // end anonymous namespace +} // end anonymous namespace /// Register optimization patterns as "canonicalization" patterns /// on the ONNXMatMultOp. void ONNXAddOp::getCanonicalizationPatterns( - OwningRewritePatternList& results, MLIRContext* context) { + OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } void ONNXGemmOp::getCanonicalizationPatterns( - OwningRewritePatternList& results, MLIRContext* context) { - results.insert(context); + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); } /// on the ONNXIdentityOp. void ONNXIdentityOp::getCanonicalizationPatterns( - OwningRewritePatternList& results, MLIRContext* context) { + OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } -///on the ONNXPadConstantValueOp. +/// on the ONNXPadConstantValueOp. void ONNXPadConstantValueOp::getCanonicalizationPatterns( - OwningRewritePatternList& result, MLIRContext* context) { + OwningRewritePatternList &result, MLIRContext *context) { result.insert(context); } diff --git a/src/Transform/ONNX/ONNXDecompose.cpp b/src/Transform/ONNX/ONNXDecompose.cpp index 78dc07a..2032344 100644 --- a/src/Transform/ONNX/ONNXDecompose.cpp +++ b/src/Transform/ONNX/ONNXDecompose.cpp @@ -27,7 +27,8 @@ namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Transform/ONNX/ONNXDecompose.inc" -struct DecomposeONNXToONNXPass : public PassWrapper { +struct DecomposeONNXToONNXPass + : public PassWrapper { void runOnFunction() final; }; } // end anonymous namespace. diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 9b81db7..e18f8fe 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -9,10 +9,10 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/StandardTypes.h" #include "src/Interface/ShapeInferenceInterface.hpp" #include "src/Pass/Passes.hpp" @@ -25,7 +25,8 @@ namespace { * candidate operations and propagating the shape information until the list * of operations is empty [credit MLIR authors]. 
*/ -class ShapeInferencePass : public mlir::PassWrapper { +class ShapeInferencePass + : public mlir::PassWrapper { public: void runOnFunction() override { auto f = getFunction(); @@ -63,8 +64,7 @@ public: if (auto terminator_op = f.getBody().back().getTerminator()) { auto results = terminator_op->getOperandTypes(); - f.setType(FunctionType::get( - f.getType().getInputs(), + f.setType(FunctionType::get(f.getType().getInputs(), std::vector(results.begin(), results.end()), f.getContext())); } } @@ -146,5 +146,5 @@ std::unique_ptr mlir::createShapeInferencePass() { return std::make_unique(); } -static PassRegistration - pass("shape-inference", "Shape inference for frontend dialects."); +static PassRegistration pass( + "shape-inference", "Shape inference for frontend dialects."); diff --git a/src/main.cpp b/src/main.cpp index 7cc57a1..80faeda 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -14,30 +14,30 @@ using namespace onnx_mlir; int main(int argc, char *argv[]) { registerDialects(); - llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options", - "These are frontend options."); - llvm::cl::opt inputFilename( - llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), + llvm::cl::OptionCategory OnnxMlirOptions( + "ONNX MLIR Options", "These are frontend options."); + llvm::cl::opt inputFilename(llvm::cl::Positional, + llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::opt emissionTarget( llvm::cl::desc("Choose target to emit:"), llvm::cl::values( clEnumVal(EmitONNXBasic, - "Ingest ONNX and emit the basic ONNX operations without" - "inferred shapes."), - clEnumVal(EmitONNXIR, - "Ingest ONNX and emit corresponding ONNX dialect."), - clEnumVal(EmitMLIR, - "Lower model to MLIR built-in transformation dialect."), + "Ingest ONNX and emit the basic ONNX operations without" + "inferred shapes."), + clEnumVal( + EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."), + clEnumVal( + EmitMLIR, "Lower model to MLIR built-in transformation dialect."), clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."), clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) " "LLVM bitcode for model.")), llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::HideUnrelatedOptions(OnnxMlirOptions); - llvm::cl::ParseCommandLineOptions(argc, argv, - "ONNX MLIR modular optimizer driver\n"); + llvm::cl::ParseCommandLineOptions( + argc, argv, "ONNX MLIR modular optimizer driver\n"); mlir::MLIRContext context; mlir::OwningModuleRef module;
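Editorial note, appended after the patch excerpt rather than inside any hunk: most of the changes above are mechanical clang-format reflows, but the hunk in src/Builder/FrontendDialectHelper.cpp also makes it easy to see that the FLOAT/INT32/INT64 cases in InitializedTensorMapping::EmitInitializerForInputTensor differ only in the element type and the CreateArrayAttribute<T> instantiation. The C++ sketch below shows one way that duplication could be collapsed. It is illustrative only: createDenseAttr is a hypothetical helper name, and it assumes CreateArrayAttribute<T>(initializer) returns a std::vector<T>, as the calls in that hunk suggest; the MLIR calls used here (RankedTensorType::get, DenseElementsAttr::get, getF32Type, getIntegerType) all appear verbatim in the patch.

    #include "llvm/ADT/ArrayRef.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/StandardTypes.h"
    #include "onnx/onnx_pb.h"

    // Hypothetical helper: build a DenseElementsAttr for one ONNX initializer.
    // Assumes CreateArrayAttribute<T>(initializer) yields a std::vector<T>.
    template <typename T>
    mlir::DenseElementsAttr createDenseAttr(const onnx::TensorProto &initializer,
        llvm::ArrayRef<int64_t> tensorDims, mlir::Type elementType) {
      const auto &values = CreateArrayAttribute<T>(initializer);
      auto tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
      return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
    }

    // Inside EmitInitializerForInputTensor, the switch would then only pick
    // the element type for each ONNX data type:
    switch (initializer.data_type()) {
    case onnx::TensorProto::FLOAT:
      constantDenseAttribute =
          createDenseAttr<float>(initializer, tensorDims, builder.getF32Type());
      break;
    case onnx::TensorProto::INT32:
      constantDenseAttribute = createDenseAttr<int32_t>(
          initializer, tensorDims, builder.getIntegerType(32));
      break;
    case onnx::TensorProto::INT64:
      constantDenseAttribute = createDenseAttr<int64_t>(
          initializer, tensorDims, builder.getIntegerType(64));
      break;
    }

Whether such a refactor is worth doing is independent of this formatting-only patch; the sketch is only meant to make the repeated pattern in that hunk easier to follow.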