[NFC] Set up clang-format Github Action (#119)

* Run clang-format on all source code.

* Add Clang-Format Github Action.

* Apply patch produced by Clang-Format Bot.

* nit.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tian Jin 2020-05-13 22:37:51 +08:00 committed by GitHub
parent 24343177b8
commit 7f2bffb27d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 272 additions and 260 deletions

40
.github/workflows/main.yml vendored Normal file
View File

@ -0,0 +1,40 @@
# This is a basic workflow to help you get started with Actions
name: Clang-Format Bot
# Controls when the action will run. Triggers the workflow on push or pull request
# events but only for the master branch
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "build"
build:
# The type of runner that the job will run on
runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
- name: clang-format lint
uses: DoozyX/clang-format-lint-action@v0.5
with:
# Source folder to check formatting
source: ./src
# Version of clang-format
clangFormatVersion: 9 # optional, default is 9
# Runs a single command using the runners shell
- name: Run a one-line script
run: echo Hello, world!
# Runs a set of commands using the runners shell
- name: Run a multi-line script
run: |
echo Add other actions to build,
echo test, and deploy your project.

View File

@ -12,8 +12,8 @@
namespace onnx_mlir { namespace onnx_mlir {
void replaceAll(std::string &str, const std::string &from, void replaceAll(
const std::string &to) { std::string &str, const std::string &from, const std::string &to) {
if (from.empty()) if (from.empty())
return; return;
size_t start_pos = 0; size_t start_pos = 0;
@ -121,7 +121,6 @@ void InitializedTensorMapping::AddMapping(
nameToInitializedTensor.emplace(name, tensor); nameToInitializedTensor.emplace(name, tensor);
} }
bool InitializedTensorMapping::ContainKey(std::string name) { bool InitializedTensorMapping::ContainKey(std::string name) {
return nameToInitializedTensor.count(name) != 0; return nameToInitializedTensor.count(name) != 0;
} }
@ -132,8 +131,8 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
onnx::TensorProto initializer = GetInitializedTensor(name); onnx::TensorProto initializer = GetInitializedTensor(name);
// Tensor dimensions. // Tensor dimensions.
llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(), llvm::ArrayRef<int64_t> tensorDims(
initializer.dims().size()); initializer.dims().data(), initializer.dims().size());
// Emit ConstantOp and record the mapping between the input and // Emit ConstantOp and record the mapping between the input and
// the constant value. // the constant value.
@ -143,8 +142,7 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
mlir::ShapedType tensorType; mlir::ShapedType tensorType;
switch (initializer.data_type()) { switch (initializer.data_type()) {
case (onnx::TensorProto::FLOAT): { case (onnx::TensorProto::FLOAT): {
const auto& arrayAttrInitializer = const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
CreateArrayAttribute<float>(initializer);
elementType = builder.getF32Type(); elementType = builder.getF32Type();
tensorType = mlir::RankedTensorType::get(tensorDims, elementType); tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
constantDenseAttribute = mlir::DenseElementsAttr::get( constantDenseAttribute = mlir::DenseElementsAttr::get(

View File

@ -20,8 +20,8 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
@ -37,11 +37,12 @@
#endif #endif
#include "onnx/onnx_pb.h" #include "onnx/onnx_pb.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
void replaceAll(std::string &str, const std::string &from, void replaceAll(
const std::string &to); std::string &str, const std::string &from, const std::string &to);
std::string legalize_name(std::string name); std::string legalize_name(std::string name);
@ -86,13 +87,13 @@ struct InitializedTensorMapping {
// This will allow the propagation of shape information passed in as an // This will allow the propagation of shape information passed in as an
// argument to operations such as Reshape and will enable other // argument to operations such as Reshape and will enable other
// optimizations such as constant folding. // optimizations such as constant folding.
mlir::Value EmitInitializerForInputTensor(mlir::Location loc, mlir::Value EmitInitializerForInputTensor(
mlir::OpBuilder &builder, std::string name); mlir::Location loc, mlir::OpBuilder &builder, std::string name);
// Get initialized tensor. // Get initialized tensor.
onnx::TensorProto &GetInitializedTensor(std::string name) { onnx::TensorProto &GetInitializedTensor(std::string name) {
assert(nameToInitializedTensor.find(name) != assert(
nameToInitializedTensor.end() && nameToInitializedTensor.find(name) != nameToInitializedTensor.end() &&
"Tensor initializer not found"); "Tensor initializer not found");
return nameToInitializedTensor.at(name); return nameToInitializedTensor.at(name);
} }

View File

@ -127,8 +127,8 @@ private:
* @param input onnx input tensor ValueInfoProto. * @param input onnx input tensor ValueInfoProto.
* @param symbol mlir input argument. * @param symbol mlir input argument.
*/ */
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, void ImportInputTensorSymbol(
mlir::Value symbol) { const onnx::ValueInfoProto &input, mlir::Value symbol) {
auto input_tensor_legalized_name = legalize_name(input.name()); auto input_tensor_legalized_name = legalize_name(input.name());
assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) && assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
"Found duplicate legalized input tensor names."); "Found duplicate legalized input tensor names.");
@ -136,8 +136,7 @@ private:
} }
typedef bstd::variant<int64_t, std::vector<int64_t>, float, typedef bstd::variant<int64_t, std::vector<int64_t>, float,
std::vector<float>, std::string, std::vector<float>, std::string, std::vector<std::string>>
std::vector<std::string>>
AttrValueType; AttrValueType;
struct ONNXAttrVisitor { struct ONNXAttrVisitor {
@ -213,8 +212,8 @@ private:
llvm_unreachable("Failed to convert attribute proto to name/value pair"); llvm_unreachable("Failed to convert attribute proto to name/value pair");
} }
std::vector<mlir::NamedAttribute> std::vector<mlir::NamedAttribute> ImportNodeAttributes(
ImportNodeAttributes(const onnx::NodeProto &node) { const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes; std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) { for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i); auto attr = node.attribute(i);
@ -281,14 +280,13 @@ private:
// TODO: Handle optional inputs. // TODO: Handle optional inputs.
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]), frontend_symbols_.AddMapping(
*(op.getODSResults(i).begin())); legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
} }
} }
template <typename T> template <typename T>
void buildOperation(const onnx::NodeProto &node, void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
int expectedNumOperands = -1,
int expectedNumResults = -1) { int expectedNumResults = -1) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) for (const auto &item : node.input())
@ -299,8 +297,8 @@ private:
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
} }
buildOutputAndOperation<T>(node, inputs, expectedNumOperands, buildOutputAndOperation<T>(
expectedNumResults); node, inputs, expectedNumOperands, expectedNumResults);
} }
void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) { void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
@ -310,8 +308,7 @@ private:
item = node.input()[i]; item = node.input()[i];
// For the second argument, check if there exists an initializer. // For the second argument, check if there exists an initializer.
if (initializedTensors.ContainKey(legalize_name(item))) { if (initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back( inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
initializedTensors.EmitInitializerForInputTensor(
UnknownLoc(), builder_, legalize_name(item))); UnknownLoc(), builder_, legalize_name(item)));
} else if (frontend_symbols_.ContainKey(legalize_name(item))) { } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -372,7 +369,6 @@ private:
#if INCLUDE_ONNX_ML == 1 #if INCLUDE_ONNX_ML == 1
#include "src/Builder/MLOpBuildTable.inc" #include "src/Builder/MLOpBuildTable.inc"
#endif #endif
} }
/*! /*!
@ -400,8 +396,8 @@ private:
ret_vals.push_back(tensor_val); ret_vals.push_back(tensor_val);
} }
void ImportGraph(const onnx::GraphProto &graph, void ImportGraph(
const std::string &name = "main_graph") { const onnx::GraphProto &graph, const std::string &name = "main_graph") {
// Maintain a mapping between the parameter and its initializer. // Maintain a mapping between the parameter and its initializer.
for (auto initializer : graph.initializer()) { for (auto initializer : graph.initializer()) {
auto name = initializer.name(); auto name = initializer.name();
@ -426,8 +422,7 @@ private:
// Emit the entry point operation which specifies the number of user // Emit the entry point operation which specifies the number of user
// inputs and outputs. // inputs and outputs.
auto entryPoint = mlir::ONNXEntryPointOp::create( auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
UnknownLoc(), mainFunc,
/*numInputs=*/graph.input().size() - graph.initializer().size(), /*numInputs=*/graph.input().size() - graph.initializer().size(),
/*numOutputs=*/graph.output().size()); /*numOutputs=*/graph.output().size());
@ -454,8 +449,8 @@ private:
// Create a NoneTyped constant to be used for optional operation inputs // Create a NoneTyped constant to be used for optional operation inputs
// which are not used. // which are not used.
none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(), none_ =
builder_.getUnitAttr()); builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph. // Import nodes in the graph.
for (const auto &item : graph.node()) { for (const auto &item : graph.node()) {
@ -483,8 +478,7 @@ private:
namespace onnx_mlir { namespace onnx_mlir {
void ImportFrontendModelFile(std::string model_fname, void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext &context, mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
mlir::OwningModuleRef &module) {
onnx::ModelProto model; onnx::ModelProto model;
std::fstream input(model_fname, std::ios::in | std::ios::binary); std::fstream input(model_fname, std::ios::in | std::ios::binary);

View File

@ -36,8 +36,7 @@ namespace onnx_mlir {
* @return MLIR::module generated for the ONNX model. * @return MLIR::module generated for the ONNX model.
*/ */
void ImportFrontendModelFile(std::string model_fname, void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext &context, mlir::MLIRContext &context, mlir::OwningModuleRef &module);
mlir::OwningModuleRef &module);
/*! /*!
* TODO: Import models into other extension dialects that cover the * TODO: Import models into other extension dialects that cover the

View File

@ -1,4 +1,5 @@
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering --------===// //====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering
//--------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -21,10 +22,9 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public: public:
using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern; using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXEntryPointOp op, LogicalResult matchAndRewrite(
PatternRewriter &rewriter) const override { ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>( rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(op,
op,
op.getAttrOfType<SymbolRefAttr>( op.getAttrOfType<SymbolRefAttr>(
ONNXEntryPointOp::getEntryPointFuncAttrName()), ONNXEntryPointOp::getEntryPointFuncAttrName()),
op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()), op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
@ -55,8 +55,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// We define the specific operations, or dialects, that are legal targets for // We define the specific operations, or dialects, that are legal targets for
// this lowering. // this lowering.
target target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
// TODO: enable this once more ops are supported. // TODO: enable this once more ops are supported.
// We also define the ONNX dialect as Illegal so that the conversion will fail // We also define the ONNX dialect as Illegal so that the conversion will fail
@ -81,8 +80,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// Type conversion for function signatures. // Type conversion for function signatures.
// Call MLIR FuncOp signature conversion when result type is // Call MLIR FuncOp signature conversion when result type is
// a ranked tensor. // a ranked tensor.
populateFuncOpTypeConversionPattern(patterns, &getContext(), populateFuncOpTypeConversionPattern(
tensor_to_memref_converter); patterns, &getContext(), tensor_to_memref_converter);
// Frontend operation lowering. // Frontend operation lowering.
// Math // Math
@ -119,5 +118,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
return std::make_unique<FrontendToKrnlLoweringPass>(); return std::make_unique<FrontendToKrnlLoweringPass>();
} }
static PassRegistration<FrontendToKrnlLoweringPass> static PassRegistration<FrontendToKrnlLoweringPass> pass(
pass("lower-frontend", "Lower frontend ops to Krnl dialect."); "lower-frontend", "Lower frontend ops to Krnl dialect.");

View File

@ -499,8 +499,7 @@ template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of // An element-wise unary operation must have all operands and the result of
@ -566,8 +565,7 @@ template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result // An element-wise variadic operation must have all operands and the result

View File

@ -1,4 +1,5 @@
//===----------------- Gemm.cpp - Lowering Gemm Op -------------------------===// //===----------------- Gemm.cpp - Lowering Gemm Op
//-------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -17,8 +18,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
ONNXGemmOpLowering(MLIRContext *ctx) ONNXGemmOpLowering(MLIRContext *ctx)
: ConversionPattern(GemmOp::getOperationName(), 1, ctx) {} : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
bool hasBias = !op->getOperand(2).getType().isa<NoneType>(); bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
@ -32,11 +32,9 @@ struct ONNXGemmOpLowering : public ConversionPattern {
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat()); llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr = auto betaAttr = FloatAttr::get(memRefType.getElementType(),
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat()); llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr); auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -101,8 +99,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
optimizedReductionLoops.reserve(1); optimizedReductionLoops.reserve(1);
reductionLoops.push_back(originalLoops[2]); reductionLoops.push_back(originalLoops[2]);
optimizedReductionLoops.push_back(optimizedLoops[2]); optimizedReductionLoops.push_back(optimizedLoops[2]);
KrnlIterateOperandPack reductionPack(rewriter, reductionLoops, KrnlIterateOperandPack reductionPack(
optimizedReductionLoops); rewriter, reductionLoops, optimizedReductionLoops);
// Induction variable for the reduction dimension // Induction variable for the reduction dimension
// Try to find and use a static value from A or B first. // Try to find and use a static value from A or B first.
// If it failed then use a dynamic value. // If it failed then use a dynamic value.
@ -167,8 +165,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs); auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB); auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (hasBias) { if (hasBias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, auto loopCIVs = getLoopIVsForBroadcasting(
broadcastedDimInfo); loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs); auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC); auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC); auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
@ -214,7 +212,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
} }
}; };
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns, void populateLoweringONNXGemmOpPattern(
MLIRContext *ctx) { OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx); patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
} }

View File

@ -16,8 +16,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
ONNXMatMulOpLowering(MLIRContext *ctx) ONNXMatMulOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
@ -119,8 +118,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
// Define loops for batch dimensions. // Define loops for batch dimensions.
std::vector<Value> originalLoops; std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops; std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, Block *optimizationBlock = defineLoops(
optimizedLoops, memRefShape.size()); rewriter, loc, originalLoops, optimizedLoops, memRefShape.size());
// Outer KrnlIterateOp // Outer KrnlIterateOp
SmallVector<Value, 4> loopBatchIVs; SmallVector<Value, 4> loopBatchIVs;
@ -139,8 +138,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]); outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]); optimizedOuterLoops.push_back(optimizedLoops[i]);
} }
KrnlIterateOperandPack outerPack(rewriter, outerLoops, KrnlIterateOperandPack outerPack(
optimizedOuterLoops); rewriter, outerLoops, optimizedOuterLoops);
for (int i = 0; i < batchAxes.size(); ++i) { for (int i = 0; i < batchAxes.size(); ++i) {
addDimensionToPack(rewriter, loc, outerPack, alloc, i); addDimensionToPack(rewriter, loc, outerPack, alloc, i);
} }
@ -176,11 +175,11 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
optimizedMatmulLoops.emplace_back( optimizedMatmulLoops.emplace_back(
optimizedLoops[memRefShape.size() - i]); optimizedLoops[memRefShape.size() - i]);
} }
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, KrnlIterateOperandPack matmulPack(
optimizedMatmulLoops); rewriter, matmulLoops, optimizedMatmulLoops);
for (int i = 2; i > 0; --i) { for (int i = 2; i > 0; --i) {
addDimensionToPack(rewriter, loc, matmulPack, alloc, addDimensionToPack(
memRefShape.size() - i); rewriter, loc, matmulPack, alloc, memRefShape.size() - i);
} }
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack); matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
} else { } else {
@ -190,10 +189,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]); matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
optimizedMatmulLoops.emplace_back( optimizedMatmulLoops.emplace_back(
optimizedLoops[memRefShape.size() - 1]); optimizedLoops[memRefShape.size() - 1]);
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, KrnlIterateOperandPack matmulPack(
optimizedMatmulLoops); rewriter, matmulLoops, optimizedMatmulLoops);
addDimensionToPack(rewriter, loc, matmulPack, alloc, addDimensionToPack(
memRefShape.size() - 1); rewriter, loc, matmulPack, alloc, memRefShape.size() - 1);
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack); matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
} }
@ -230,8 +229,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
std::vector<Value> optimizedReduceLoops; std::vector<Value> optimizedReduceLoops;
Block *optimizationReduceBlock = Block *optimizationReduceBlock =
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
KrnlIterateOperandPack reducePack(rewriter, reduceLoops, KrnlIterateOperandPack reducePack(
optimizedReduceLoops); rewriter, reduceLoops, optimizedReduceLoops);
addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1); addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack); auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
@ -292,8 +291,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
std::vector<Value> optimizedReduceLoops; std::vector<Value> optimizedReduceLoops;
Block *optimizationReduceBlock = Block *optimizationReduceBlock =
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
KrnlIterateOperandPack reducePack(rewriter, reduceLoops, KrnlIterateOperandPack reducePack(
optimizedReduceLoops); rewriter, reduceLoops, optimizedReduceLoops);
addDimensionToPack(rewriter, loc, reducePack, A, 0); addDimensionToPack(rewriter, loc, reducePack, A, 0);
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack); auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);

View File

@ -102,8 +102,7 @@ struct ONNXReductionOpLowering : public ConversionPattern {
ONNXReductionOpLowering(MLIRContext *ctx) ONNXReductionOpLowering(MLIRContext *ctx)
: ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {} : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
/* /*
* Condition: reduction function must be associative and commutative. * Condition: reduction function must be associative and commutative.

View File

@ -1,4 +1,5 @@
//===--------------- Conv.cpp - Lowering Convolution Op --------------------===// //===--------------- Conv.cpp - Lowering Convolution Op
//--------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -175,14 +176,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
// Emit the bias, if needed. // Emit the bias, if needed.
if (hasBias) { if (hasBias) {
auto loadResult = auto loadResult = rewriter.create<LoadOp>(loc, alloc, resultIndices);
rewriter.create<LoadOp>(loc, alloc, resultIndices);
SmallVector<Value, 4> biasIndices; SmallVector<Value, 4> biasIndices;
biasIndices.emplace_back(kernel); biasIndices.emplace_back(kernel);
auto loadBias = auto loadBias = rewriter.create<LoadOp>(loc, biasOperand, kernel);
rewriter.create<LoadOp>(loc, biasOperand, kernel); auto resultWithBias =
auto resultWithBias = rewriter.create<MulFOp>( rewriter.create<MulFOp>(loc, loadResult, loadBias);
loc, loadResult, loadBias);
// Store initializer value into output location. // Store initializer value into output location.
rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices); rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
} }

View File

@ -459,7 +459,8 @@ struct ONNXPoolOpLowering : public ConversionPattern {
poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext()); poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
// poolStartMap and poolEndMap // poolStartMap and poolEndMap
poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext()); poolStartMap =
AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext()); poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
} }

View File

@ -36,9 +36,7 @@ MemRefType convertToMemRefType(Type type) {
/// Insert an allocation and deallocation for the given MemRefType. /// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc, Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter, PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands) {
bool insertDealloc,
ArrayRef<Value> operands) {
// Put together alloc operands for any dynamic dimensions of the memref. // Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc; AllocOp alloc;
if (!operands.empty()) { if (!operands.empty()) {
@ -64,10 +62,10 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
auto operandDim = auto operandDim =
rewriter.create<DimOp>(loc, operands[i], operandDimIdx); rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
if (maxDim) { if (maxDim) {
auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, auto maxCondition = rewriter.create<CmpIOp>(
operandDim, maxDim); loc, CmpIPredicate::sgt, operandDim, maxDim);
maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim, maxDim = rewriter.create<SelectOp>(
maxDim); loc, maxCondition, operandDim, maxDim);
} else { } else {
maxDim = operandDim; maxDim = operandDim;
} }
@ -122,8 +120,8 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
// Create a mapping from result type's dimensions to input type's dimensions, // Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input // given that the result type is the result of a reduction op over the input
// type. // type.
std::map<int64_t, int64_t> std::map<int64_t, int64_t> getReductionMapping(
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) { MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> OutInDimMap; std::map<int64_t, int64_t> OutInDimMap;
int64_t rank = inputTy.getRank(); int64_t rank = inputTy.getRank();
@ -152,9 +150,8 @@ getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
// Add bounds associated with the op operand to the KRNL iteration pack. // Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimenions are supported. // Dynamic dimenions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter, void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
Location loc, KrnlIterateOperandPack &pack, KrnlIterateOperandPack &pack, Value operand, int index) {
Value operand, int index) {
auto shape = operand.getType().cast<MemRefType>().getShape(); auto shape = operand.getType().cast<MemRefType>().getShape();
if (shape[index] < 0) { if (shape[index] < 0) {
pack.pushConstantBound(0); pack.pushConstantBound(0);
@ -168,10 +165,9 @@ void addDimensionToPack(ConversionPatternRewriter &rewriter,
// Function that defines the KRNL dialect loops and their respective // Function that defines the KRNL dialect loops and their respective
// optimized version. // optimized version.
KrnlOptimizeLoopsOp KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc, Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
std::vector<Value> &loops, int64_t numLoops) {
std::vector<Value> &optimizedLoops, int64_t numLoops) {
// Define loops. // Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops); auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
loops.reserve(numLoops); loops.reserve(numLoops);
@ -190,8 +186,7 @@ emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
// Function that emits the loops and their optimized version. // Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block. // The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc, Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
std::vector<Value> &optimizedLoops,
int64_t numLoops) { int64_t numLoops) {
KrnlOptimizeLoopsOp optimizedLoopsOp = KrnlOptimizeLoopsOp optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops); emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
@ -201,10 +196,9 @@ Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
// Function which emits a basic set of loops and optimized loops // Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization // for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function. // block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand( void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
ConversionPatternRewriter &rewriter, Location loc, Value operand, Location loc, Value operand, std::vector<Value> &originalLoops,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
KrnlIterateOp &iterateOp) {
// Operand shape. // Operand shape.
auto shape = operand.getType().cast<MemRefType>().getShape(); auto shape = operand.getType().cast<MemRefType>().getShape();
@ -240,9 +234,9 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
std::map<int, std::map<int, Value>> std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter, MemRefType memRefType,
MemRefType memRefType, ArrayRef<Value> operands) { ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size(); int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in // For unknown dimensions, we need to get dimension values at runtime in
@ -286,9 +280,8 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// Extract induction variables that are used for broadcasting values of a // Extract induction variables that are used for broadcasting values of a
// given operand. // given operand.
std::vector<Value> std::vector<Value> getLoopIVsForBroadcasting(Location loc,
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims) { std::map<int, Value> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the // `operand` must has a ranked type. This should have been checked by the
// shape inference pass. // shape inference pass.
@ -310,8 +303,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
// If its value is 1, it is broadcasted dimension. // If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension. // Otherwise, non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0); auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero, auto idx = rewriter.create<SelectOp>(
loopIVs[loopIdx]); loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx); newLoopIVs.insert(newLoopIVs.begin(), idx);
} else { } else {
// Non-broadcasted dimension // Non-broadcasted dimension

View File

@ -38,8 +38,7 @@ struct ONNXConstantOpLowering : public ConversionPattern {
numElements *= shape[i]; numElements *= shape[i];
// Emit the constant global in Krnl dialect. // Emit the constant global in Krnl dialect.
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
memRefType,
rewriter.getI64ArrayAttr(shape), rewriter.getI64ArrayAttr(shape),
rewriter.getStringAttr("constant_" + std::to_string(constantID)), rewriter.getStringAttr("constant_" + std::to_string(constantID)),
constantOp.value().getValue()); constantOp.value().getValue());

View File

@ -39,7 +39,6 @@ MLONNXOpsDialect::MLONNXOpsDialect(mlir::MLIRContext *ctx)
>(); >();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -19,8 +19,8 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Interface/PromotableConstOperandsOpInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
#include "src/Interface/ShapeInferenceInterface.hpp"
namespace mlir { namespace mlir {

View File

@ -19,8 +19,8 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Interface/PromotableConstOperandsOpInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
#include "src/Interface/ShapeInferenceInterface.hpp"
namespace mlir { namespace mlir {

View File

@ -10,7 +10,6 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/Interface/PromotableConstOperandsOpInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
using namespace mlir; using namespace mlir;
@ -20,4 +19,3 @@ using namespace mlir;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc" #include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc"

View File

@ -46,10 +46,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
void EmitLLVMBitCode( void EmitLLVMBitCode(
const mlir::OwningModuleRef &module, string outputFilename) { const mlir::OwningModuleRef &module, string outputFilename) {
error_code error; error_code error;
llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error, llvm::raw_fd_ostream moduleBitcodeStream(
llvm::sys::fs::F_None); outputFilename, error, llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), llvm::WriteBitcodeToFile(
moduleBitcodeStream); *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
moduleBitcodeStream.flush(); moduleBitcodeStream.flush();
} }
@ -99,7 +99,6 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
assert(inputIsONNX != inputIsMLIR && assert(inputIsONNX != inputIsMLIR &&
"Either ONNX model or MLIR file needs to be provided."); "Either ONNX model or MLIR file needs to be provided.");
if (inputIsONNX) { if (inputIsONNX) {
ImportFrontendModelFile(inputFilename, context, module); ImportFrontendModelFile(inputFilename, context, module);
} else { } else {

View File

@ -28,9 +28,9 @@
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
@ -66,8 +66,7 @@ void processInputFile(std::string inputFilename,
mlir::OwningModuleRef &module); mlir::OwningModuleRef &module);
void outputCode( void outputCode(
mlir::OwningModuleRef &module, std::string filename, mlir::OwningModuleRef &module, std::string filename, std::string extension);
std::string extension);
void emitOutputFiles(std::string outputBaseName, void emitOutputFiles(std::string outputBaseName,
EmissionTargetType emissionTarget, mlir::MLIRContext &context, EmissionTargetType emissionTarget, mlir::MLIRContext &context,

View File

@ -33,8 +33,8 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
return tensorDict->tensorDict[tensorDict->orderedNames[idx]]; return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
} }
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, void setDynMemRef(
DynMemRef *tensor) { OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
if (tensorDict->orderedNames.size() <= idx) if (tensorDict->orderedNames.size() <= idx)
tensorDict->orderedNames.resize(idx + 1); tensorDict->orderedNames.resize(idx + 1);

View File

@ -39,7 +39,6 @@ typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
extern "C" { extern "C" {
#endif #endif
// Get number of dynamic memrefs in OrderedDynMemRefDict dict. // Get number of dynamic memrefs in OrderedDynMemRefDict dict.
int numDynMemRefs(OrderedDynMemRefDict *dict); int numDynMemRefs(OrderedDynMemRefDict *dict);
@ -53,8 +52,8 @@ DynMemRef *createDynMemRef(int rank);
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i); DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
// Set the i-th dynmemref in orderedDict to be dynMemRef. // Set the i-th dynmemref in orderedDict to be dynMemRef.
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, void setDynMemRef(
DynMemRef *dynMemRef); OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
// Get data pointer from dynMemRef. // Get data pointer from dynMemRef.
void *getData(DynMemRef *dynMemRef); void *getData(DynMemRef *dynMemRef);

View File

@ -1,14 +1,14 @@
#include "Runtime.hpp" #include "Runtime.hpp"
ExecutionSession::ExecutionSession(std::string sharedLibPath, ExecutionSession::ExecutionSession(
std::string entryPointName) { std::string sharedLibPath, std::string entryPointName) {
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY); _sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
_entryPointFunc = _entryPointFunc =
(entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str()); (entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
} }
std::vector<py::array> std::vector<py::array> ExecutionSession::run(
ExecutionSession::run(std::vector<py::array> inputsPyArray) { std::vector<py::array> inputsPyArray) {
assert(_entryPointFunc && "entry point not loaded"); assert(_entryPointFunc && "entry point not loaded");
auto *wrappedInput = createOrderedDynMemRefDict(); auto *wrappedInput = createOrderedDynMemRefDict();
int inputIdx = 0; int inputIdx = 0;
@ -40,8 +40,8 @@ ExecutionSession::run(std::vector<py::array> inputsPyArray) {
auto *wrappedOutput = _entryPointFunc(wrappedInput); auto *wrappedOutput = _entryPointFunc(wrappedInput);
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) { for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
auto *dynMemRef = getDynMemRef(wrappedOutput, i); auto *dynMemRef = getDynMemRef(wrappedOutput, i);
auto shape = std::vector<int64_t>(dynMemRef->sizes, auto shape = std::vector<int64_t>(
dynMemRef->sizes + dynMemRef->rank); dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
outputPyArrays.emplace_back( outputPyArrays.emplace_back(
py::array(py::dtype("float32"), shape, dynMemRef->data)); py::array(py::dtype("float32"), shape, dynMemRef->data));
} }

View File

@ -144,9 +144,9 @@ public:
assert(krnlGlobalOp.value().hasValue() && assert(krnlGlobalOp.value().hasValue() &&
"Krnl Global must always have a value"); "Krnl Global must always have a value");
global = rewriter.create<LLVM::GlobalOp>(loc, global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
llvmGlobalType, /*isConstant=*/true, /*isConstant=*/true, LLVM::Linkage::Internal, name,
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue()); krnlGlobalOp.value().getValue());
} }
// Some frequently used types. // Some frequently used types.

View File

@ -12,8 +12,8 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include <numeric>
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include <numeric>
using namespace mlir; using namespace mlir;

View File

@ -27,7 +27,8 @@ namespace {
/// Include the patterns defined in the Declarative Rewrite framework. /// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc" #include "src/Transform/ONNX/ONNXDecompose.inc"
struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> { struct DecomposeONNXToONNXPass
: public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
void runOnFunction() final; void runOnFunction() final;
}; };
} // end anonymous namespace. } // end anonymous namespace.

View File

@ -9,10 +9,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/IR/StandardTypes.h"
#include "src/Interface/ShapeInferenceInterface.hpp" #include "src/Interface/ShapeInferenceInterface.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
@ -25,7 +25,8 @@ namespace {
* candidate operations and propagating the shape information until the list * candidate operations and propagating the shape information until the list
* of operations is empty [credit MLIR authors]. * of operations is empty [credit MLIR authors].
*/ */
class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> { class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
public: public:
void runOnFunction() override { void runOnFunction() override {
auto f = getFunction(); auto f = getFunction();
@ -63,8 +64,7 @@ public:
if (auto terminator_op = f.getBody().back().getTerminator()) { if (auto terminator_op = f.getBody().back().getTerminator()) {
auto results = terminator_op->getOperandTypes(); auto results = terminator_op->getOperandTypes();
f.setType(FunctionType::get( f.setType(FunctionType::get(f.getType().getInputs(),
f.getType().getInputs(),
std::vector<Type>(results.begin(), results.end()), f.getContext())); std::vector<Type>(results.begin(), results.end()), f.getContext()));
} }
} }
@ -146,5 +146,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>(); return std::make_unique<ShapeInferencePass>();
} }
static PassRegistration<ShapeInferencePass> static PassRegistration<ShapeInferencePass> pass(
pass("shape-inference", "Shape inference for frontend dialects."); "shape-inference", "Shape inference for frontend dialects.");

View File

@ -14,10 +14,10 @@ using namespace onnx_mlir;
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
registerDialects(); registerDialects();
llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options", llvm::cl::OptionCategory OnnxMlirOptions(
"These are frontend options."); "ONNX MLIR Options", "These are frontend options.");
llvm::cl::opt<string> inputFilename( llvm::cl::opt<string> inputFilename(llvm::cl::Positional,
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"), llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<EmissionTargetType> emissionTarget( llvm::cl::opt<EmissionTargetType> emissionTarget(
@ -26,18 +26,18 @@ int main(int argc, char *argv[]) {
clEnumVal(EmitONNXBasic, clEnumVal(EmitONNXBasic,
"Ingest ONNX and emit the basic ONNX operations without" "Ingest ONNX and emit the basic ONNX operations without"
"inferred shapes."), "inferred shapes."),
clEnumVal(EmitONNXIR, clEnumVal(
"Ingest ONNX and emit corresponding ONNX dialect."), EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."),
clEnumVal(EmitMLIR, clEnumVal(
"Lower model to MLIR built-in transformation dialect."), EmitMLIR, "Lower model to MLIR built-in transformation dialect."),
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."), clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) " clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) "
"LLVM bitcode for model.")), "LLVM bitcode for model.")),
llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions); llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
llvm::cl::ParseCommandLineOptions(argc, argv, llvm::cl::ParseCommandLineOptions(
"ONNX MLIR modular optimizer driver\n"); argc, argv, "ONNX MLIR modular optimizer driver\n");
mlir::MLIRContext context; mlir::MLIRContext context;
mlir::OwningModuleRef module; mlir::OwningModuleRef module;