[NFC] Set up clang-format Github Action (#119)
* Run clang-format on all source code. * Add Clang-Format Github Action. * Apply patch produced by Clang-Format Bot. * nit. Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
24343177b8
commit
7f2bffb27d
|
@ -0,0 +1,40 @@
|
|||
# This is a basic workflow to help you get started with Actions
|
||||
|
||||
name: Clang-Format Bot
|
||||
|
||||
# Controls when the action will run. Triggers the workflow on push or pull request
|
||||
# events but only for the master branch
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
# A workflow run is made up of one or more jobs that can run sequentially or in parallel
|
||||
jobs:
|
||||
# This workflow contains a single job called "build"
|
||||
build:
|
||||
# The type of runner that the job will run on
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
# Steps represent a sequence of tasks that will be executed as part of the job
|
||||
steps:
|
||||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||
- uses: actions/checkout@v2
|
||||
- name: clang-format lint
|
||||
uses: DoozyX/clang-format-lint-action@v0.5
|
||||
with:
|
||||
# Source folder to check formatting
|
||||
source: ./src
|
||||
# Version of clang-format
|
||||
clangFormatVersion: 9 # optional, default is 9
|
||||
|
||||
# Runs a single command using the runners shell
|
||||
- name: Run a one-line script
|
||||
run: echo Hello, world!
|
||||
|
||||
# Runs a set of commands using the runners shell
|
||||
- name: Run a multi-line script
|
||||
run: |
|
||||
echo Add other actions to build,
|
||||
echo test, and deploy your project.
|
|
@ -12,8 +12,8 @@
|
|||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void replaceAll(std::string &str, const std::string &from,
|
||||
const std::string &to) {
|
||||
void replaceAll(
|
||||
std::string &str, const std::string &from, const std::string &to) {
|
||||
if (from.empty())
|
||||
return;
|
||||
size_t start_pos = 0;
|
||||
|
@ -121,7 +121,6 @@ void InitializedTensorMapping::AddMapping(
|
|||
nameToInitializedTensor.emplace(name, tensor);
|
||||
}
|
||||
|
||||
|
||||
bool InitializedTensorMapping::ContainKey(std::string name) {
|
||||
return nameToInitializedTensor.count(name) != 0;
|
||||
}
|
||||
|
@ -132,8 +131,8 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
|
|||
onnx::TensorProto initializer = GetInitializedTensor(name);
|
||||
|
||||
// Tensor dimensions.
|
||||
llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
|
||||
initializer.dims().size());
|
||||
llvm::ArrayRef<int64_t> tensorDims(
|
||||
initializer.dims().data(), initializer.dims().size());
|
||||
|
||||
// Emit ConstantOp and record the mapping between the input and
|
||||
// the constant value.
|
||||
|
@ -142,33 +141,32 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
|
|||
mlir::Type elementType;
|
||||
mlir::ShapedType tensorType;
|
||||
switch (initializer.data_type()) {
|
||||
case (onnx::TensorProto::FLOAT): {
|
||||
const auto& arrayAttrInitializer =
|
||||
CreateArrayAttribute<float>(initializer);
|
||||
elementType = builder.getF32Type();
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
case (onnx::TensorProto::INT32): {
|
||||
const auto& arrayAttrInitializer =
|
||||
CreateArrayAttribute<int32_t>(initializer);
|
||||
elementType = builder.getIntegerType(32);
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
case (onnx::TensorProto::INT64): {
|
||||
const auto& arrayAttrInitializer =
|
||||
CreateArrayAttribute<int64_t>(initializer);
|
||||
elementType = builder.getIntegerType(64);
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
case (onnx::TensorProto::FLOAT): {
|
||||
const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
|
||||
elementType = builder.getF32Type();
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
case (onnx::TensorProto::INT32): {
|
||||
const auto &arrayAttrInitializer =
|
||||
CreateArrayAttribute<int32_t>(initializer);
|
||||
elementType = builder.getIntegerType(32);
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
case (onnx::TensorProto::INT64): {
|
||||
const auto &arrayAttrInitializer =
|
||||
CreateArrayAttribute<int64_t>(initializer);
|
||||
elementType = builder.getIntegerType(64);
|
||||
tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||
constantDenseAttribute = mlir::DenseElementsAttr::get(
|
||||
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Create ConstantOp for dense array.
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
@ -37,11 +37,12 @@
|
|||
#endif
|
||||
|
||||
#include "onnx/onnx_pb.h"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void replaceAll(std::string &str, const std::string &from,
|
||||
const std::string &to);
|
||||
void replaceAll(
|
||||
std::string &str, const std::string &from, const std::string &to);
|
||||
|
||||
std::string legalize_name(std::string name);
|
||||
|
||||
|
@ -86,14 +87,14 @@ struct InitializedTensorMapping {
|
|||
// This will allow the propagation of shape information passed in as an
|
||||
// argument to operations such as Reshape and will enable other
|
||||
// optimizations such as constant folding.
|
||||
mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
|
||||
mlir::OpBuilder &builder, std::string name);
|
||||
mlir::Value EmitInitializerForInputTensor(
|
||||
mlir::Location loc, mlir::OpBuilder &builder, std::string name);
|
||||
|
||||
// Get initialized tensor.
|
||||
onnx::TensorProto& GetInitializedTensor(std::string name) {
|
||||
assert(nameToInitializedTensor.find(name) !=
|
||||
nameToInitializedTensor.end() &&
|
||||
"Tensor initializer not found");
|
||||
onnx::TensorProto &GetInitializedTensor(std::string name) {
|
||||
assert(
|
||||
nameToInitializedTensor.find(name) != nameToInitializedTensor.end() &&
|
||||
"Tensor initializer not found");
|
||||
return nameToInitializedTensor.at(name);
|
||||
}
|
||||
|
||||
|
|
|
@ -127,8 +127,8 @@ private:
|
|||
* @param input onnx input tensor ValueInfoProto.
|
||||
* @param symbol mlir input argument.
|
||||
*/
|
||||
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
|
||||
mlir::Value symbol) {
|
||||
void ImportInputTensorSymbol(
|
||||
const onnx::ValueInfoProto &input, mlir::Value symbol) {
|
||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||
assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
||||
"Found duplicate legalized input tensor names.");
|
||||
|
@ -136,8 +136,7 @@ private:
|
|||
}
|
||||
|
||||
typedef bstd::variant<int64_t, std::vector<int64_t>, float,
|
||||
std::vector<float>, std::string,
|
||||
std::vector<std::string>>
|
||||
std::vector<float>, std::string, std::vector<std::string>>
|
||||
AttrValueType;
|
||||
|
||||
struct ONNXAttrVisitor {
|
||||
|
@ -213,8 +212,8 @@ private:
|
|||
llvm_unreachable("Failed to convert attribute proto to name/value pair");
|
||||
}
|
||||
|
||||
std::vector<mlir::NamedAttribute>
|
||||
ImportNodeAttributes(const onnx::NodeProto &node) {
|
||||
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
|
||||
const onnx::NodeProto &node) {
|
||||
std::vector<mlir::NamedAttribute> attributes;
|
||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||
auto attr = node.attribute(i);
|
||||
|
@ -281,26 +280,25 @@ private:
|
|||
// TODO: Handle optional inputs.
|
||||
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||
for (int i = 0; i < node.output().size(); i++) {
|
||||
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
|
||||
*(op.getODSResults(i).begin()));
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void buildOperation(const onnx::NodeProto &node,
|
||||
int expectedNumOperands = -1,
|
||||
int expectedNumResults = -1) {
|
||||
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
|
||||
int expectedNumResults = -1) {
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (const auto &item : node.input())
|
||||
if (initializedTensors.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
|
||||
UnknownLoc(), builder_, legalize_name(item)));
|
||||
UnknownLoc(), builder_, legalize_name(item)));
|
||||
} else if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
}
|
||||
|
||||
buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
|
||||
expectedNumResults);
|
||||
buildOutputAndOperation<T>(
|
||||
node, inputs, expectedNumOperands, expectedNumResults);
|
||||
}
|
||||
|
||||
void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
|
||||
|
@ -310,9 +308,8 @@ private:
|
|||
item = node.input()[i];
|
||||
// For the second argument, check if there exists an initializer.
|
||||
if (initializedTensors.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(
|
||||
initializedTensors.EmitInitializerForInputTensor(
|
||||
UnknownLoc(), builder_, legalize_name(item)));
|
||||
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
|
||||
UnknownLoc(), builder_, legalize_name(item)));
|
||||
} else if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
}
|
||||
|
@ -372,7 +369,6 @@ private:
|
|||
#if INCLUDE_ONNX_ML == 1
|
||||
#include "src/Builder/MLOpBuildTable.inc"
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
/*!
|
||||
|
@ -388,8 +384,8 @@ private:
|
|||
* output tensor.
|
||||
*/
|
||||
void ImportOutputTensor(const onnx::ValueInfoProto &output,
|
||||
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
||||
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
|
||||
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
||||
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
|
||||
auto output_tensor_legalized_name = legalize_name(output.name());
|
||||
assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
||||
"Output tensor not found");
|
||||
|
@ -400,8 +396,8 @@ private:
|
|||
ret_vals.push_back(tensor_val);
|
||||
}
|
||||
|
||||
void ImportGraph(const onnx::GraphProto &graph,
|
||||
const std::string &name = "main_graph") {
|
||||
void ImportGraph(
|
||||
const onnx::GraphProto &graph, const std::string &name = "main_graph") {
|
||||
// Maintain a mapping between the parameter and its initializer.
|
||||
for (auto initializer : graph.initializer()) {
|
||||
auto name = initializer.name();
|
||||
|
@ -426,8 +422,7 @@ private:
|
|||
|
||||
// Emit the entry point operation which specifies the number of user
|
||||
// inputs and outputs.
|
||||
auto entryPoint = mlir::ONNXEntryPointOp::create(
|
||||
UnknownLoc(), mainFunc,
|
||||
auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
|
||||
/*numInputs=*/graph.input().size() - graph.initializer().size(),
|
||||
/*numOutputs=*/graph.output().size());
|
||||
|
||||
|
@ -454,8 +449,8 @@ private:
|
|||
|
||||
// Create a NoneTyped constant to be used for optional operation inputs
|
||||
// which are not used.
|
||||
none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
|
||||
builder_.getUnitAttr());
|
||||
none_ =
|
||||
builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
|
||||
|
||||
// Import nodes in the graph.
|
||||
for (const auto &item : graph.node()) {
|
||||
|
@ -483,8 +478,7 @@ private:
|
|||
namespace onnx_mlir {
|
||||
|
||||
void ImportFrontendModelFile(std::string model_fname,
|
||||
mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module) {
|
||||
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||
onnx::ModelProto model;
|
||||
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
||||
|
||||
|
|
|
@ -36,11 +36,10 @@ namespace onnx_mlir {
|
|||
* @return MLIR::module generated for the ONNX model.
|
||||
*/
|
||||
void ImportFrontendModelFile(std::string model_fname,
|
||||
mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module);
|
||||
mlir::MLIRContext &context, mlir::OwningModuleRef &module);
|
||||
|
||||
/*!
|
||||
* TODO: Import models into other extension dialects that cover the
|
||||
* operations specific to other frameworks such as Tensorflow or Pytorch.
|
||||
*/
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering --------===//
|
||||
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering
|
||||
//--------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
|
@ -21,10 +22,9 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
|
|||
public:
|
||||
using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXEntryPointOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
|
||||
op,
|
||||
LogicalResult matchAndRewrite(
|
||||
ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(op,
|
||||
op.getAttrOfType<SymbolRefAttr>(
|
||||
ONNXEntryPointOp::getEntryPointFuncAttrName()),
|
||||
op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
|
||||
|
@ -55,8 +55,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
|
|||
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering.
|
||||
target
|
||||
.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
|
||||
target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
|
||||
|
||||
// TODO: enable this once more ops are supported.
|
||||
// We also define the ONNX dialect as Illegal so that the conversion will fail
|
||||
|
@ -81,8 +80,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
|
|||
// Type conversion for function signatures.
|
||||
// Call MLIR FuncOp signature conversion when result type is
|
||||
// a ranked tensor.
|
||||
populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
||||
tensor_to_memref_converter);
|
||||
populateFuncOpTypeConversionPattern(
|
||||
patterns, &getContext(), tensor_to_memref_converter);
|
||||
|
||||
// Frontend operation lowering.
|
||||
// Math
|
||||
|
@ -119,5 +118,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
|
|||
return std::make_unique<FrontendToKrnlLoweringPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<FrontendToKrnlLoweringPass>
|
||||
pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
|
||||
static PassRegistration<FrontendToKrnlLoweringPass> pass(
|
||||
"lower-frontend", "Lower frontend ops to Krnl dialect.");
|
||||
|
|
|
@ -499,9 +499,8 @@ template <typename ElementwiseUnaryOp>
|
|||
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// TODO: Check that the types are valid.
|
||||
// An element-wise unary operation must have all operands and the result of
|
||||
// the same type. This should have been verified by the verifier.
|
||||
|
@ -566,9 +565,8 @@ template <typename ElementwiseVariadicOp>
|
|||
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// TODO: Check that the types are valid.
|
||||
// An element-wise variadic operation must have all operands and the result
|
||||
// of the same type. This should have been verified by the verifier.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
//===----------------- Gemm.cpp - Lowering Gemm Op -------------------------===//
|
||||
//===----------------- Gemm.cpp - Lowering Gemm Op
|
||||
//-------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
|
@ -17,9 +18,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
ONNXGemmOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
|
||||
|
||||
|
@ -32,12 +32,10 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
|
||||
auto alphaAttr =
|
||||
FloatAttr::get(memRefType.getElementType(),
|
||||
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
|
||||
auto betaAttr =
|
||||
FloatAttr::get(memRefType.getElementType(),
|
||||
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
|
||||
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
|
||||
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
|
||||
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
|
||||
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
|
||||
|
||||
|
@ -101,8 +99,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
optimizedReductionLoops.reserve(1);
|
||||
reductionLoops.push_back(originalLoops[2]);
|
||||
optimizedReductionLoops.push_back(optimizedLoops[2]);
|
||||
KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
|
||||
optimizedReductionLoops);
|
||||
KrnlIterateOperandPack reductionPack(
|
||||
rewriter, reductionLoops, optimizedReductionLoops);
|
||||
// Induction variable for the reduction dimension
|
||||
// Try to find and use a static value from A or B first.
|
||||
// If it failed then use a dynamic value.
|
||||
|
@ -167,8 +165,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
|
||||
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
|
||||
if (hasBias) {
|
||||
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
|
||||
broadcastedDimInfo);
|
||||
auto loopCIVs = getLoopIVsForBroadcasting(
|
||||
loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
|
||||
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
|
||||
auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
|
||||
auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
|
||||
|
@ -214,7 +212,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
}
|
||||
};
|
||||
|
||||
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx) {
|
||||
void populateLoweringONNXGemmOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
|
||||
}
|
||||
|
|
|
@ -16,9 +16,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
ONNXMatMulOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
||||
ONNXMatMulOpOperandAdaptor operandAdaptor(operands);
|
||||
|
@ -119,8 +118,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
// Define loops for batch dimensions.
|
||||
std::vector<Value> originalLoops;
|
||||
std::vector<Value> optimizedLoops;
|
||||
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
|
||||
optimizedLoops, memRefShape.size());
|
||||
Block *optimizationBlock = defineLoops(
|
||||
rewriter, loc, originalLoops, optimizedLoops, memRefShape.size());
|
||||
|
||||
// Outer KrnlIterateOp
|
||||
SmallVector<Value, 4> loopBatchIVs;
|
||||
|
@ -139,8 +138,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
outerLoops.push_back(originalLoops[i]);
|
||||
optimizedOuterLoops.push_back(optimizedLoops[i]);
|
||||
}
|
||||
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
|
||||
optimizedOuterLoops);
|
||||
KrnlIterateOperandPack outerPack(
|
||||
rewriter, outerLoops, optimizedOuterLoops);
|
||||
for (int i = 0; i < batchAxes.size(); ++i) {
|
||||
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
|
||||
}
|
||||
|
@ -176,11 +175,11 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
optimizedMatmulLoops.emplace_back(
|
||||
optimizedLoops[memRefShape.size() - i]);
|
||||
}
|
||||
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
|
||||
optimizedMatmulLoops);
|
||||
KrnlIterateOperandPack matmulPack(
|
||||
rewriter, matmulLoops, optimizedMatmulLoops);
|
||||
for (int i = 2; i > 0; --i) {
|
||||
addDimensionToPack(rewriter, loc, matmulPack, alloc,
|
||||
memRefShape.size() - i);
|
||||
addDimensionToPack(
|
||||
rewriter, loc, matmulPack, alloc, memRefShape.size() - i);
|
||||
}
|
||||
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
|
||||
} else {
|
||||
|
@ -190,10 +189,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
|
||||
optimizedMatmulLoops.emplace_back(
|
||||
optimizedLoops[memRefShape.size() - 1]);
|
||||
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
|
||||
optimizedMatmulLoops);
|
||||
addDimensionToPack(rewriter, loc, matmulPack, alloc,
|
||||
memRefShape.size() - 1);
|
||||
KrnlIterateOperandPack matmulPack(
|
||||
rewriter, matmulLoops, optimizedMatmulLoops);
|
||||
addDimensionToPack(
|
||||
rewriter, loc, matmulPack, alloc, memRefShape.size() - 1);
|
||||
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
|
||||
}
|
||||
|
||||
|
@ -230,8 +229,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
std::vector<Value> optimizedReduceLoops;
|
||||
Block *optimizationReduceBlock =
|
||||
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
|
||||
KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
|
||||
optimizedReduceLoops);
|
||||
KrnlIterateOperandPack reducePack(
|
||||
rewriter, reduceLoops, optimizedReduceLoops);
|
||||
addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
|
||||
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
|
||||
|
||||
|
@ -292,8 +291,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
std::vector<Value> optimizedReduceLoops;
|
||||
Block *optimizationReduceBlock =
|
||||
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
|
||||
KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
|
||||
optimizedReduceLoops);
|
||||
KrnlIterateOperandPack reducePack(
|
||||
rewriter, reduceLoops, optimizedReduceLoops);
|
||||
addDimensionToPack(rewriter, loc, reducePack, A, 0);
|
||||
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
|
||||
|
||||
|
|
|
@ -102,9 +102,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
|||
ONNXReductionOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
/*
|
||||
* Condition: reduction function must be associative and commutative.
|
||||
*
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
//===--------------- Conv.cpp - Lowering Convolution Op --------------------===//
|
||||
//===--------------- Conv.cpp - Lowering Convolution Op
|
||||
//--------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
|
@ -175,14 +176,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
|
|||
|
||||
// Emit the bias, if needed.
|
||||
if (hasBias) {
|
||||
auto loadResult =
|
||||
rewriter.create<LoadOp>(loc, alloc, resultIndices);
|
||||
auto loadResult = rewriter.create<LoadOp>(loc, alloc, resultIndices);
|
||||
SmallVector<Value, 4> biasIndices;
|
||||
biasIndices.emplace_back(kernel);
|
||||
auto loadBias =
|
||||
rewriter.create<LoadOp>(loc, biasOperand, kernel);
|
||||
auto resultWithBias = rewriter.create<MulFOp>(
|
||||
loc, loadResult, loadBias);
|
||||
auto loadBias = rewriter.create<LoadOp>(loc, biasOperand, kernel);
|
||||
auto resultWithBias =
|
||||
rewriter.create<MulFOp>(loc, loadResult, loadBias);
|
||||
// Store initializer value into output location.
|
||||
rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
|
||||
}
|
||||
|
|
|
@ -459,7 +459,8 @@ struct ONNXPoolOpLowering : public ConversionPattern {
|
|||
poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
|
||||
|
||||
// poolStartMap and poolEndMap
|
||||
poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
|
||||
poolStartMap =
|
||||
AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
|
||||
poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
|
||||
}
|
||||
|
||||
|
|
|
@ -36,9 +36,7 @@ MemRefType convertToMemRefType(Type type) {
|
|||
|
||||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter &rewriter,
|
||||
bool insertDealloc,
|
||||
ArrayRef<Value> operands) {
|
||||
PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands) {
|
||||
// Put together alloc operands for any dynamic dimensions of the memref.
|
||||
AllocOp alloc;
|
||||
if (!operands.empty()) {
|
||||
|
@ -64,10 +62,10 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
auto operandDim =
|
||||
rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
|
||||
if (maxDim) {
|
||||
auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
|
||||
operandDim, maxDim);
|
||||
maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
|
||||
maxDim);
|
||||
auto maxCondition = rewriter.create<CmpIOp>(
|
||||
loc, CmpIPredicate::sgt, operandDim, maxDim);
|
||||
maxDim = rewriter.create<SelectOp>(
|
||||
loc, maxCondition, operandDim, maxDim);
|
||||
} else {
|
||||
maxDim = operandDim;
|
||||
}
|
||||
|
@ -122,8 +120,8 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
|
|||
// Create a mapping from result type's dimensions to input type's dimensions,
|
||||
// given that the result type is the result of a reduction op over the input
|
||||
// type.
|
||||
std::map<int64_t, int64_t>
|
||||
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
|
||||
std::map<int64_t, int64_t> getReductionMapping(
|
||||
MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
|
||||
std::map<int64_t, int64_t> OutInDimMap;
|
||||
int64_t rank = inputTy.getRank();
|
||||
|
||||
|
@ -152,9 +150,8 @@ getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
|
|||
|
||||
// Add bounds associated with the op operand to the KRNL iteration pack.
|
||||
// Dynamic dimenions are supported.
|
||||
void addDimensionToPack(ConversionPatternRewriter &rewriter,
|
||||
Location loc, KrnlIterateOperandPack &pack,
|
||||
Value operand, int index) {
|
||||
void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
|
||||
KrnlIterateOperandPack &pack, Value operand, int index) {
|
||||
auto shape = operand.getType().cast<MemRefType>().getShape();
|
||||
if (shape[index] < 0) {
|
||||
pack.pushConstantBound(0);
|
||||
|
@ -168,10 +165,9 @@ void addDimensionToPack(ConversionPatternRewriter &rewriter,
|
|||
|
||||
// Function that defines the KRNL dialect loops and their respective
|
||||
// optimized version.
|
||||
KrnlOptimizeLoopsOp
|
||||
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
|
||||
std::vector<Value> &loops,
|
||||
std::vector<Value> &optimizedLoops, int64_t numLoops) {
|
||||
KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
|
||||
Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
|
||||
int64_t numLoops) {
|
||||
// Define loops.
|
||||
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
|
||||
loops.reserve(numLoops);
|
||||
|
@ -190,9 +186,8 @@ emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
|
|||
// Function that emits the loops and their optimized version.
|
||||
// The function returns a reference to the inner optimization block.
|
||||
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
|
||||
std::vector<Value> &loops,
|
||||
std::vector<Value> &optimizedLoops,
|
||||
int64_t numLoops) {
|
||||
std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
|
||||
int64_t numLoops) {
|
||||
KrnlOptimizeLoopsOp optimizedLoopsOp =
|
||||
emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
|
||||
return &optimizedLoopsOp.region().front();
|
||||
|
@ -201,10 +196,9 @@ Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
|
|||
// Function which emits a basic set of loops and optimized loops
|
||||
// for a given operation argument. A reference to the loop optimization
|
||||
// block is returned in the last argument of the function.
|
||||
void emitKrnlLoopsAndIterationForOperand(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value operand,
|
||||
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
|
||||
KrnlIterateOp &iterateOp) {
|
||||
void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value operand, std::vector<Value> &originalLoops,
|
||||
KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
|
||||
// Operand shape.
|
||||
auto shape = operand.getType().cast<MemRefType>().getShape();
|
||||
|
||||
|
@ -240,9 +234,9 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
|
|||
|
||||
// Get run-time dimension information for unknown dimensions used for
|
||||
// broadcasting.
|
||||
std::map<int, std::map<int, Value>>
|
||||
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
||||
MemRefType memRefType, ArrayRef<Value> operands) {
|
||||
std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
|
||||
ConversionPatternRewriter &rewriter, MemRefType memRefType,
|
||||
ArrayRef<Value> operands) {
|
||||
auto memRefShape = memRefType.getShape();
|
||||
int64_t rank = memRefShape.size();
|
||||
// For unknown dimensions, we need to get dimension values at runtime in
|
||||
|
@ -286,10 +280,9 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
|||
|
||||
// Extract induction variables that are used for broadcasting values of a
|
||||
// given operand.
|
||||
std::vector<Value>
|
||||
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
||||
ArrayRef<Value> loopIVs, Value operand,
|
||||
std::map<int, Value> broadcastedDims) {
|
||||
std::vector<Value> getLoopIVsForBroadcasting(Location loc,
|
||||
ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
|
||||
std::map<int, Value> broadcastedDims) {
|
||||
// `operand` must has a ranked type. This should have been checked by the
|
||||
// shape inference pass.
|
||||
auto operandShape = operand.getType().cast<MemRefType>().getShape();
|
||||
|
@ -310,8 +303,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
|||
// If its value is 1, it is broadcasted dimension.
|
||||
// Otherwise, non-broadcasted dimension.
|
||||
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
|
||||
loopIVs[loopIdx]);
|
||||
auto idx = rewriter.create<SelectOp>(
|
||||
loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]);
|
||||
newLoopIVs.insert(newLoopIVs.begin(), idx);
|
||||
} else {
|
||||
// Non-broadcasted dimension
|
||||
|
|
|
@ -30,7 +30,7 @@ struct ONNXConcatOpLowering : public ConversionPattern {
|
|||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
auto resultShape = memRefType.getShape();
|
||||
auto rank = resultShape.size();
|
||||
assert((axis >=0 && axis < rank) && "Concat axis out of bounds");
|
||||
assert((axis >= 0 && axis < rank) && "Concat axis out of bounds");
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
|
|
|
@ -17,8 +17,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {
|
|||
|
||||
ONNXConstantOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {
|
||||
constantID = 0;
|
||||
}
|
||||
constantID = 0;
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
|
@ -34,12 +34,11 @@ struct ONNXConstantOpLowering : public ConversionPattern {
|
|||
// Shape based computations.
|
||||
auto shape = memRefType.getShape();
|
||||
int64_t numElements = 1;
|
||||
for (int i=0; i<shape.size(); ++i)
|
||||
for (int i = 0; i < shape.size(); ++i)
|
||||
numElements *= shape[i];
|
||||
|
||||
// Emit the constant global in Krnl dialect.
|
||||
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
|
||||
memRefType,
|
||||
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
|
||||
rewriter.getI64ArrayAttr(shape),
|
||||
rewriter.getStringAttr("constant_" + std::to_string(constantID)),
|
||||
constantOp.value().getValue());
|
||||
|
|
|
@ -24,15 +24,15 @@ enum Kinds {
|
|||
}
|
||||
|
||||
class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
|
||||
public:
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
// Support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; }
|
||||
|
||||
// Get a unique instance of Loop type.
|
||||
static LoopType get(mlir::MLIRContext* context) {
|
||||
static LoopType get(mlir::MLIRContext *context) {
|
||||
return Base::get(context, KrnlTypes::Loop);
|
||||
}
|
||||
};
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
|
|
@ -39,7 +39,6 @@ MLONNXOpsDialect::MLONNXOpsDialect(mlir::MLIRContext *ctx)
|
|||
>();
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -19,14 +19,14 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class MLONNXOpsDialect : public Dialect {
|
||||
public:
|
||||
MLONNXOpsDialect(MLIRContext* context);
|
||||
public:
|
||||
MLONNXOpsDialect(MLIRContext *context);
|
||||
|
||||
/// Provide a utility accessor to the dialect namespace. This is used by
|
||||
/// several utilities for casting between dialects.
|
||||
|
@ -38,6 +38,6 @@ class MLONNXOpsDialect : public Dialect {
|
|||
#define GET_OP_CLASSES
|
||||
#include "src/Dialect/MLONNX/MLONNXOps.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
} // end namespace mlir
|
||||
|
||||
namespace onnx_mlir {}
|
||||
|
|
|
@ -19,14 +19,14 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ONNXOpsDialect : public Dialect {
|
||||
public:
|
||||
ONNXOpsDialect(MLIRContext* context);
|
||||
public:
|
||||
ONNXOpsDialect(MLIRContext *context);
|
||||
|
||||
/// Provide a utility accessor to the dialect namespace. This is used by
|
||||
/// several utilities for casting between dialects.
|
||||
|
@ -38,6 +38,6 @@ class ONNXOpsDialect : public Dialect {
|
|||
#define GET_OP_CLASSES
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
} // end namespace mlir
|
||||
|
||||
namespace onnx_mlir {}
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -20,4 +19,3 @@ using namespace mlir;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc"
|
||||
|
||||
|
|
|
@ -22,4 +22,4 @@ namespace mlir {
|
|||
/// Include the auto-generated declarations.
|
||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
} // end namespace mlir
|
|
@ -16,4 +16,4 @@ namespace mlir {
|
|||
/// Include the auto-generated declarations.
|
||||
#include "src/Interface/ShapeInference.cpp.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -18,4 +18,4 @@ namespace mlir {
|
|||
/// Include the auto-generated declarations.
|
||||
#include "src/Interface/ShapeInference.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
#ifdef _WIN32
|
||||
#include <io.h>
|
||||
#else
|
||||
#else
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
|
@ -22,7 +22,7 @@ using namespace std;
|
|||
using namespace onnx_mlir;
|
||||
|
||||
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module) {
|
||||
mlir::OwningModuleRef &module) {
|
||||
// Handle '.mlir' input to the ONNX MLIR frontend.
|
||||
// The mlir format indicates that one or more of the supported
|
||||
// representations are used in the file.
|
||||
|
@ -46,10 +46,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
|||
void EmitLLVMBitCode(
|
||||
const mlir::OwningModuleRef &module, string outputFilename) {
|
||||
error_code error;
|
||||
llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error,
|
||||
llvm::sys::fs::F_None);
|
||||
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
|
||||
moduleBitcodeStream);
|
||||
llvm::raw_fd_ostream moduleBitcodeStream(
|
||||
outputFilename, error, llvm::sys::fs::F_None);
|
||||
llvm::WriteBitcodeToFile(
|
||||
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
||||
moduleBitcodeStream.flush();
|
||||
}
|
||||
|
||||
|
@ -90,7 +90,7 @@ void addKrnlToLLVMPasses(mlir::PassManager &pm) {
|
|||
}
|
||||
|
||||
void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
|
||||
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||
// Decide if the input file is an ONNX model or a model specified
|
||||
// in MLIR. The extension of the file is the decider.
|
||||
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
|
||||
|
@ -99,7 +99,6 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
|
|||
assert(inputIsONNX != inputIsMLIR &&
|
||||
"Either ONNX model or MLIR file needs to be provided.");
|
||||
|
||||
|
||||
if (inputIsONNX) {
|
||||
ImportFrontendModelFile(inputFilename, context, module);
|
||||
} else {
|
||||
|
@ -119,8 +118,8 @@ void outputCode(
|
|||
module->dump();
|
||||
fflush(stderr);
|
||||
// set modified stderr as original stderr
|
||||
_dup2(stderrOrigin, _fileno( stderr ));
|
||||
#else
|
||||
_dup2(stderrOrigin, _fileno(stderr));
|
||||
#else
|
||||
if (fork() == 0) {
|
||||
freopen(tempFilename.c_str(), "w", stderr);
|
||||
module->dump();
|
||||
|
@ -151,7 +150,7 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
|
|||
// necessary when emitting the .bc file.
|
||||
if (emissionTarget == EmitLLVMBC) {
|
||||
// Write LLVM bitcode to disk.
|
||||
string outputFilename = outputBaseName + ".bc";
|
||||
string outputFilename = outputBaseName + ".bc";
|
||||
EmitLLVMBitCode(module, outputFilename);
|
||||
printf("LLVM bitcode written to %s\n", outputFilename.c_str());
|
||||
} else {
|
||||
|
|
|
@ -28,9 +28,9 @@
|
|||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
@ -46,10 +46,10 @@ enum EmissionTargetType {
|
|||
};
|
||||
|
||||
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module);
|
||||
mlir::OwningModuleRef &module);
|
||||
|
||||
void EmitLLVMBitCode(
|
||||
const mlir::OwningModuleRef &module, std::string outputFilename);
|
||||
const mlir::OwningModuleRef &module, std::string outputFilename);
|
||||
|
||||
void registerDialects();
|
||||
|
||||
|
@ -66,8 +66,7 @@ void processInputFile(std::string inputFilename,
|
|||
mlir::OwningModuleRef &module);
|
||||
|
||||
void outputCode(
|
||||
mlir::OwningModuleRef &module, std::string filename,
|
||||
std::string extension);
|
||||
mlir::OwningModuleRef &module, std::string filename, std::string extension);
|
||||
|
||||
void emitOutputFiles(std::string outputBaseName,
|
||||
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
|
||||
|
|
|
@ -38,4 +38,4 @@ std::unique_ptr<Pass> createElideConstGlobalValuePass();
|
|||
/// Pass for lowering Krnl dialect to LLVM dialect.
|
||||
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
|
||||
|
||||
} // end namespace mlir
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
enum DYN_MEMREF_DATA_TYPE {
|
||||
UNDEFINED = 0;
|
||||
// Basic types.
|
||||
FLOAT = 1; // float
|
||||
UINT8 = 2; // uint8_t
|
||||
INT8 = 3; // int8_t
|
||||
UINT16 = 4; // uint16_t
|
||||
INT16 = 5; // int16_t
|
||||
INT32 = 6; // int32_t
|
||||
INT64 = 7; // int64_t
|
||||
STRING = 8; // string
|
||||
BOOL = 9; // bool
|
||||
FLOAT = 1; // float
|
||||
UINT8 = 2; // uint8_t
|
||||
INT8 = 3; // int8_t
|
||||
UINT16 = 4; // uint16_t
|
||||
INT16 = 5; // int16_t
|
||||
INT32 = 6; // int32_t
|
||||
INT64 = 7; // int64_t
|
||||
STRING = 8; // string
|
||||
BOOL = 9; // bool
|
||||
|
||||
// IEEE754 half-precision floating-point format (16 bits wide).
|
||||
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
|
||||
|
@ -18,8 +18,8 @@ enum DYN_MEMREF_DATA_TYPE {
|
|||
DOUBLE = 11;
|
||||
UINT32 = 12;
|
||||
UINT64 = 13;
|
||||
COMPLEX64 = 14; // complex with float32 real and imaginary components
|
||||
COMPLEX128 = 15; // complex with float64 real and imaginary components
|
||||
COMPLEX64 = 14; // complex with float32 real and imaginary components
|
||||
COMPLEX128 = 15; // complex with float64 real and imaginary components
|
||||
|
||||
// Non-IEEE floating-point format based on IEEE754 single-precision
|
||||
// floating-point number truncated to 16 bits.
|
||||
|
|
|
@ -33,8 +33,8 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
|
|||
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
|
||||
}
|
||||
|
||||
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
|
||||
DynMemRef *tensor) {
|
||||
void setDynMemRef(
|
||||
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
|
||||
if (tensorDict->orderedNames.size() <= idx)
|
||||
tensorDict->orderedNames.resize(idx + 1);
|
||||
|
||||
|
|
|
@ -38,7 +38,6 @@ typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
|
||||
// Get number of dynamic memrefs in OrderedDynMemRefDict dict.
|
||||
int numDynMemRefs(OrderedDynMemRefDict *dict);
|
||||
|
@ -53,8 +52,8 @@ DynMemRef *createDynMemRef(int rank);
|
|||
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
|
||||
|
||||
// Set the i-th dynmemref in orderedDict to be dynMemRef.
|
||||
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
|
||||
DynMemRef *dynMemRef);
|
||||
void setDynMemRef(
|
||||
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
|
||||
|
||||
// Get data pointer from dynMemRef.
|
||||
void *getData(DynMemRef *dynMemRef);
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
#include "Runtime.hpp"
|
||||
|
||||
ExecutionSession::ExecutionSession(std::string sharedLibPath,
|
||||
std::string entryPointName) {
|
||||
ExecutionSession::ExecutionSession(
|
||||
std::string sharedLibPath, std::string entryPointName) {
|
||||
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
|
||||
_entryPointFunc =
|
||||
(entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
|
||||
}
|
||||
|
||||
std::vector<py::array>
|
||||
ExecutionSession::run(std::vector<py::array> inputsPyArray) {
|
||||
std::vector<py::array> ExecutionSession::run(
|
||||
std::vector<py::array> inputsPyArray) {
|
||||
assert(_entryPointFunc && "entry point not loaded");
|
||||
auto *wrappedInput = createOrderedDynMemRefDict();
|
||||
int inputIdx = 0;
|
||||
|
@ -40,8 +40,8 @@ ExecutionSession::run(std::vector<py::array> inputsPyArray) {
|
|||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
|
||||
auto *dynMemRef = getDynMemRef(wrappedOutput, i);
|
||||
auto shape = std::vector<int64_t>(dynMemRef->sizes,
|
||||
dynMemRef->sizes + dynMemRef->rank);
|
||||
auto shape = std::vector<int64_t>(
|
||||
dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
|
||||
outputPyArrays.emplace_back(
|
||||
py::array(py::dtype("float32"), shape, dynMemRef->data));
|
||||
}
|
||||
|
|
|
@ -144,9 +144,9 @@ public:
|
|||
|
||||
assert(krnlGlobalOp.value().hasValue() &&
|
||||
"Krnl Global must always have a value");
|
||||
global = rewriter.create<LLVM::GlobalOp>(loc,
|
||||
llvmGlobalType, /*isConstant=*/true,
|
||||
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
|
||||
global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, name,
|
||||
krnlGlobalOp.value().getValue());
|
||||
}
|
||||
|
||||
// Some frequently used types.
|
||||
|
|
|
@ -75,7 +75,7 @@ public:
|
|||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
ConstantOp::getCanonicalizationPatterns(patterns, context);
|
||||
applyPatternsAndFoldGreedily(f, patterns);
|
||||
applyPatternsAndFoldGreedily(f, patterns);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
|
|
@ -12,35 +12,35 @@
|
|||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include <numeric>
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "src/Transform/ONNX/ONNXCombine.inc"
|
||||
} // end anonymous namespace
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Register optimization patterns as "canonicalization" patterns
|
||||
/// on the ONNXMatMultOp.
|
||||
void ONNXAddOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<MulAddToGemmOptPattern>(context);
|
||||
}
|
||||
|
||||
void ONNXGemmOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<FuseGemmFollowedByAddition>(context);
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<FuseGemmFollowedByAddition>(context);
|
||||
}
|
||||
/// on the ONNXIdentityOp.
|
||||
void ONNXIdentityOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<IdentityEliminationPattern>(context);
|
||||
}
|
||||
|
||||
///on the ONNXPadConstantValueOp.
|
||||
/// on the ONNXPadConstantValueOp.
|
||||
void ONNXPadConstantValueOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& result, MLIRContext* context) {
|
||||
OwningRewritePatternList &result, MLIRContext *context) {
|
||||
result.insert<ConstantPadPattern>(context);
|
||||
}
|
||||
|
|
|
@ -27,7 +27,8 @@ namespace {
|
|||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "src/Transform/ONNX/ONNXDecompose.inc"
|
||||
|
||||
struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
|
||||
struct DecomposeONNXToONNXPass
|
||||
: public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
|
||||
void runOnFunction() final;
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
|
|
@ -9,10 +9,10 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||
#include "src/Pass/Passes.hpp"
|
||||
|
@ -25,7 +25,8 @@ namespace {
|
|||
* candidate operations and propagating the shape information until the list
|
||||
* of operations is empty [credit MLIR authors].
|
||||
*/
|
||||
class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
|
||||
class ShapeInferencePass
|
||||
: public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto f = getFunction();
|
||||
|
@ -63,8 +64,7 @@ public:
|
|||
|
||||
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
||||
auto results = terminator_op->getOperandTypes();
|
||||
f.setType(FunctionType::get(
|
||||
f.getType().getInputs(),
|
||||
f.setType(FunctionType::get(f.getType().getInputs(),
|
||||
std::vector<Type>(results.begin(), results.end()), f.getContext()));
|
||||
}
|
||||
}
|
||||
|
@ -146,5 +146,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
|
|||
return std::make_unique<ShapeInferencePass>();
|
||||
}
|
||||
|
||||
static PassRegistration<ShapeInferencePass>
|
||||
pass("shape-inference", "Shape inference for frontend dialects.");
|
||||
static PassRegistration<ShapeInferencePass> pass(
|
||||
"shape-inference", "Shape inference for frontend dialects.");
|
||||
|
|
24
src/main.cpp
24
src/main.cpp
|
@ -14,30 +14,30 @@ using namespace onnx_mlir;
|
|||
int main(int argc, char *argv[]) {
|
||||
registerDialects();
|
||||
|
||||
llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options",
|
||||
"These are frontend options.");
|
||||
llvm::cl::opt<string> inputFilename(
|
||||
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
|
||||
llvm::cl::OptionCategory OnnxMlirOptions(
|
||||
"ONNX MLIR Options", "These are frontend options.");
|
||||
llvm::cl::opt<string> inputFilename(llvm::cl::Positional,
|
||||
llvm::cl::desc("<input file>"), llvm::cl::init("-"),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<EmissionTargetType> emissionTarget(
|
||||
llvm::cl::desc("Choose target to emit:"),
|
||||
llvm::cl::values(
|
||||
clEnumVal(EmitONNXBasic,
|
||||
"Ingest ONNX and emit the basic ONNX operations without"
|
||||
"inferred shapes."),
|
||||
clEnumVal(EmitONNXIR,
|
||||
"Ingest ONNX and emit corresponding ONNX dialect."),
|
||||
clEnumVal(EmitMLIR,
|
||||
"Lower model to MLIR built-in transformation dialect."),
|
||||
"Ingest ONNX and emit the basic ONNX operations without"
|
||||
"inferred shapes."),
|
||||
clEnumVal(
|
||||
EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."),
|
||||
clEnumVal(
|
||||
EmitMLIR, "Lower model to MLIR built-in transformation dialect."),
|
||||
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
|
||||
clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) "
|
||||
"LLVM bitcode for model.")),
|
||||
llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
||||
"ONNX MLIR modular optimizer driver\n");
|
||||
llvm::cl::ParseCommandLineOptions(
|
||||
argc, argv, "ONNX MLIR modular optimizer driver\n");
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
|
|
Loading…
Reference in New Issue