[NFC] Set up clang-format Github Action (#119)
* Run clang-format on all source code.
* Add Clang-Format Github Action.
* Apply patch produced by Clang-Format Bot.
* nit.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
parent 24343177b8
commit 7f2bffb27d
@@ -0,0 +1,40 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Clang-Format Bot
+
+# Controls when the action will run. Triggers the workflow on push or pull request
+# events but only for the master branch
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+    branches: [ master ]
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+  # This workflow contains a single job called "build"
+  build:
+    # The type of runner that the job will run on
+    runs-on: ubuntu-latest
+
+    # Steps represent a sequence of tasks that will be executed as part of the job
+    steps:
+      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
+      - uses: actions/checkout@v2
+      - name: clang-format lint
+        uses: DoozyX/clang-format-lint-action@v0.5
+        with:
+          # Source folder to check formatting
+          source: ./src
+          # Version of clang-format
+          clangFormatVersion: 9 # optional, default is 9
+
+      # Runs a single command using the runners shell
+      - name: Run a one-line script
+        run: echo Hello, world!
+
+      # Runs a set of commands using the runners shell
+      - name: Run a multi-line script
+        run: |
+          echo Add other actions to build,
+          echo test, and deploy your project.
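Note that the two trailing echo steps above are leftovers from the GitHub Actions starter template (their comments still describe the template's placeholder commands). A trimmed sketch of the same workflow with only the checkout and lint steps, assuming nothing else is wanted from the template, would be:

name: Clang-Format Bot
on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: clang-format lint
        uses: DoozyX/clang-format-lint-action@v0.5
        with:
          source: ./src
          clangFormatVersion: 9

The action fails the check when any file under ./src differs from clang-format's output, which is what lets the bot enforce the formatting applied in the rest of this commit.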
@@ -12,8 +12,8 @@
 
 namespace onnx_mlir {
 
-void replaceAll(std::string &str, const std::string &from,
-    const std::string &to) {
+void replaceAll(
+    std::string &str, const std::string &from, const std::string &to) {
   if (from.empty())
     return;
   size_t start_pos = 0;
@@ -121,7+121,6 @@ void InitializedTensorMapping::AddMapping(
   nameToInitializedTensor.emplace(name, tensor);
 }
 
-
 bool InitializedTensorMapping::ContainKey(std::string name) {
   return nameToInitializedTensor.count(name) != 0;
 }
@@ -132,8 +131,8 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
   onnx::TensorProto initializer = GetInitializedTensor(name);
 
   // Tensor dimensions.
-  llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
-      initializer.dims().size());
+  llvm::ArrayRef<int64_t> tensorDims(
+      initializer.dims().data(), initializer.dims().size());
 
   // Emit ConstantOp and record the mapping between the input and
   // the constant value.
@@ -143,8 +142,7 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
   mlir::ShapedType tensorType;
   switch (initializer.data_type()) {
   case (onnx::TensorProto::FLOAT): {
-    const auto& arrayAttrInitializer =
-        CreateArrayAttribute<float>(initializer);
+    const auto &arrayAttrInitializer = CreateArrayAttribute<float>(initializer);
     elementType = builder.getF32Type();
     tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
     constantDenseAttribute = mlir::DenseElementsAttr::get(
@@ -20,8 +20,8 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
@@ -37,11 +37,12 @@
 #endif
 
 #include "onnx/onnx_pb.h"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
 
 namespace onnx_mlir {
 
-void replaceAll(std::string &str, const std::string &from,
-    const std::string &to);
+void replaceAll(
+    std::string &str, const std::string &from, const std::string &to);
 
 std::string legalize_name(std::string name);
 
@@ -86,13 +87,13 @@ struct InitializedTensorMapping {
   // This will allow the propagation of shape information passed in as an
   // argument to operations such as Reshape and will enable other
   // optimizations such as constant folding.
-  mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
-      mlir::OpBuilder &builder, std::string name);
+  mlir::Value EmitInitializerForInputTensor(
+      mlir::Location loc, mlir::OpBuilder &builder, std::string name);
 
   // Get initialized tensor.
   onnx::TensorProto &GetInitializedTensor(std::string name) {
-    assert(nameToInitializedTensor.find(name) !=
-               nameToInitializedTensor.end() &&
-           "Tensor initializer not found");
+    assert(
+        nameToInitializedTensor.find(name) != nameToInitializedTensor.end() &&
+        "Tensor initializer not found");
     return nameToInitializedTensor.at(name);
   }
@@ -127,8 +127,8 @@ private:
    * @param input onnx input tensor ValueInfoProto.
    * @param symbol mlir input argument.
    */
-  void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
-      mlir::Value symbol) {
+  void ImportInputTensorSymbol(
+      const onnx::ValueInfoProto &input, mlir::Value symbol) {
     auto input_tensor_legalized_name = legalize_name(input.name());
     assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
            "Found duplicate legalized input tensor names.");
@@ -136,8 +136,7 @@ private:
   }
 
   typedef bstd::variant<int64_t, std::vector<int64_t>, float,
-      std::vector<float>, std::string,
-      std::vector<std::string>>
+      std::vector<float>, std::string, std::vector<std::string>>
       AttrValueType;
 
   struct ONNXAttrVisitor {
@@ -213,8 +212,8 @@ private:
     llvm_unreachable("Failed to convert attribute proto to name/value pair");
   }
 
-  std::vector<mlir::NamedAttribute>
-  ImportNodeAttributes(const onnx::NodeProto &node) {
+  std::vector<mlir::NamedAttribute> ImportNodeAttributes(
+      const onnx::NodeProto &node) {
     std::vector<mlir::NamedAttribute> attributes;
     for (int i = 0; i < node.attribute_size(); ++i) {
       auto attr = node.attribute(i);
@@ -281,14 +280,13 @@ private:
     // TODO: Handle optional inputs.
     auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
     for (int i = 0; i < node.output().size(); i++) {
-      frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
-          *(op.getODSResults(i).begin()));
+      frontend_symbols_.AddMapping(
+          legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
     }
   }
 
   template <typename T>
-  void buildOperation(const onnx::NodeProto &node,
-      int expectedNumOperands = -1,
+  void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
       int expectedNumResults = -1) {
     std::vector<mlir::Value> inputs;
     for (const auto &item : node.input())
@@ -299,8 +297,8 @@ private:
       inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
     }
 
-    buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
-        expectedNumResults);
+    buildOutputAndOperation<T>(
+        node, inputs, expectedNumOperands, expectedNumResults);
   }
 
   void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
@@ -310,8 +308,7 @@ private:
       item = node.input()[i];
       // For the second argument, check if there exists an initializer.
       if (initializedTensors.ContainKey(legalize_name(item))) {
-        inputs.push_back(
-            initializedTensors.EmitInitializerForInputTensor(
-                UnknownLoc(), builder_, legalize_name(item)));
+        inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
+            UnknownLoc(), builder_, legalize_name(item)));
       } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@@ -372,7 +369,6 @@ private:
 #if INCLUDE_ONNX_ML == 1
 #include "src/Builder/MLOpBuildTable.inc"
 #endif
-
   }
 
   /*!
@@ -400,8 +396,8 @@ private:
       ret_vals.push_back(tensor_val);
   }
 
-  void ImportGraph(const onnx::GraphProto &graph,
-      const std::string &name = "main_graph") {
+  void ImportGraph(
+      const onnx::GraphProto &graph, const std::string &name = "main_graph") {
     // Maintain a mapping between the parameter and its initializer.
     for (auto initializer : graph.initializer()) {
       auto name = initializer.name();
@@ -426,8 +422,7 @@ private:
 
     // Emit the entry point operation which specifies the number of user
     // inputs and outputs.
-    auto entryPoint = mlir::ONNXEntryPointOp::create(
-        UnknownLoc(), mainFunc,
+    auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
        /*numInputs=*/graph.input().size() - graph.initializer().size(),
        /*numOutputs=*/graph.output().size());
 
@@ -454,8 +449,8 @@ private:
 
     // Create a NoneTyped constant to be used for optional operation inputs
     // which are not used.
-    none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
-        builder_.getUnitAttr());
+    none_ =
+        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
 
     // Import nodes in the graph.
     for (const auto &item : graph.node()) {
@@ -483,8 +478,7 @@ private:
 namespace onnx_mlir {
 
 void ImportFrontendModelFile(std::string model_fname,
-    mlir::MLIRContext &context,
-    mlir::OwningModuleRef &module) {
+    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
   onnx::ModelProto model;
   std::fstream input(model_fname, std::ios::in | std::ios::binary);
 
@@ -36,8 +36,7 @@ namespace onnx_mlir {
  * @return MLIR::module generated for the ONNX model.
  */
 void ImportFrontendModelFile(std::string model_fname,
-    mlir::MLIRContext &context,
-    mlir::OwningModuleRef &module);
+    mlir::MLIRContext &context, mlir::OwningModuleRef &module);
 
 /*!
  * TODO: Import models into other extension dialects that cover the
@@ -1,4 +1,5 @@
-//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering --------===//
+//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering
+//--------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -21,10 +22,9 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
 public:
   using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(ONNXEntryPointOp op,
-      PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
-        op,
+  LogicalResult matchAndRewrite(
+      ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(op,
         op.getAttrOfType<SymbolRefAttr>(
             ONNXEntryPointOp::getEntryPointFuncAttrName()),
         op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
@@ -55,8 +55,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
 
   // We define the specific operations, or dialects, that are legal targets for
   // this lowering.
-  target
-      .addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
+  target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
 
   // TODO: enable this once more ops are supported.
   // We also define the ONNX dialect as Illegal so that the conversion will fail
@@ -81,8 +80,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   // Type conversion for function signatures.
   // Call MLIR FuncOp signature conversion when result type is
   // a ranked tensor.
-  populateFuncOpTypeConversionPattern(patterns, &getContext(),
-      tensor_to_memref_converter);
+  populateFuncOpTypeConversionPattern(
+      patterns, &getContext(), tensor_to_memref_converter);
 
   // Frontend operation lowering.
   // Math
@@ -119,5 +118,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
   return std::make_unique<FrontendToKrnlLoweringPass>();
 }
 
-static PassRegistration<FrontendToKrnlLoweringPass>
-    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
+static PassRegistration<FrontendToKrnlLoweringPass> pass(
+    "lower-frontend", "Lower frontend ops to Krnl dialect.");
@@ -499,8 +499,7 @@ template <typename ElementwiseUnaryOp>
 struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
   ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
       : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     // TODO: Check that the types are valid.
     // An element-wise unary operation must have all operands and the result of
@@ -566,8 +565,7 @@ template <typename ElementwiseVariadicOp>
 struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
   ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
       : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     // TODO: Check that the types are valid.
     // An element-wise variadic operation must have all operands and the result
@@ -1,4 +1,5 @@
-//===----------------- Gemm.cpp - Lowering Gemm Op -------------------------===//
+//===----------------- Gemm.cpp - Lowering Gemm Op
+//-------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -17,8 +18,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   ONNXGemmOpLowering(MLIRContext *ctx)
       : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
     bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
@@ -32,11 +32,9 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 
     auto memRefType = convertToMemRefType(*op->result_type_begin());
 
-    auto alphaAttr =
-        FloatAttr::get(memRefType.getElementType(),
-            llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
-    auto betaAttr =
-        FloatAttr::get(memRefType.getElementType(),
-            llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
+    auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
+        llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
+    auto betaAttr = FloatAttr::get(memRefType.getElementType(),
+        llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
     auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@@ -101,8 +99,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     optimizedReductionLoops.reserve(1);
     reductionLoops.push_back(originalLoops[2]);
     optimizedReductionLoops.push_back(optimizedLoops[2]);
-    KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
-        optimizedReductionLoops);
+    KrnlIterateOperandPack reductionPack(
+        rewriter, reductionLoops, optimizedReductionLoops);
     // Induction variable for the reduction dimension
     // Try to find and use a static value from A or B first.
     // If it failed then use a dynamic value.
@@ -167,8 +165,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
       auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
       auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
       if (hasBias) {
-        auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
-            broadcastedDimInfo);
+        auto loopCIVs = getLoopIVsForBroadcasting(
+            loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
         auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
         auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
         auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
@@ -214,7 +212,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   }
 };
 
-void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
-    MLIRContext *ctx) {
+void populateLoweringONNXGemmOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
   patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
 }
@@ -16,8 +16,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
   ONNXMatMulOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
 
@@ -119,8 +118,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       // Define loops for batch dimensions.
       std::vector<Value> originalLoops;
       std::vector<Value> optimizedLoops;
-      Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
-          optimizedLoops, memRefShape.size());
+      Block *optimizationBlock = defineLoops(
+          rewriter, loc, originalLoops, optimizedLoops, memRefShape.size());
 
       // Outer KrnlIterateOp
       SmallVector<Value, 4> loopBatchIVs;
@@ -139,8 +138,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
         outerLoops.push_back(originalLoops[i]);
         optimizedOuterLoops.push_back(optimizedLoops[i]);
       }
-      KrnlIterateOperandPack outerPack(rewriter, outerLoops,
-          optimizedOuterLoops);
+      KrnlIterateOperandPack outerPack(
+          rewriter, outerLoops, optimizedOuterLoops);
       for (int i = 0; i < batchAxes.size(); ++i) {
         addDimensionToPack(rewriter, loc, outerPack, alloc, i);
       }
@@ -176,11 +175,11 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
         optimizedMatmulLoops.emplace_back(
             optimizedLoops[memRefShape.size() - i]);
       }
-      KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
-          optimizedMatmulLoops);
+      KrnlIterateOperandPack matmulPack(
+          rewriter, matmulLoops, optimizedMatmulLoops);
       for (int i = 2; i > 0; --i) {
-        addDimensionToPack(rewriter, loc, matmulPack, alloc,
-            memRefShape.size() - i);
+        addDimensionToPack(
+            rewriter, loc, matmulPack, alloc, memRefShape.size() - i);
       }
       matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
     } else {
@@ -190,10 +189,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
       optimizedMatmulLoops.emplace_back(
           optimizedLoops[memRefShape.size() - 1]);
-      KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
-          optimizedMatmulLoops);
-      addDimensionToPack(rewriter, loc, matmulPack, alloc,
-          memRefShape.size() - 1);
+      KrnlIterateOperandPack matmulPack(
+          rewriter, matmulLoops, optimizedMatmulLoops);
+      addDimensionToPack(
+          rewriter, loc, matmulPack, alloc, memRefShape.size() - 1);
       matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
     }
 
@@ -230,8 +229,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
     std::vector<Value> optimizedReduceLoops;
     Block *optimizationReduceBlock =
         defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
-    KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
-        optimizedReduceLoops);
+    KrnlIterateOperandPack reducePack(
+        rewriter, reduceLoops, optimizedReduceLoops);
     addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
     auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
 
@@ -292,8 +291,8 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
     std::vector<Value> optimizedReduceLoops;
     Block *optimizationReduceBlock =
         defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
-    KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
-        optimizedReduceLoops);
+    KrnlIterateOperandPack reducePack(
+        rewriter, reduceLoops, optimizedReduceLoops);
     addDimensionToPack(rewriter, loc, reducePack, A, 0);
     auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
 
@@ -102,8 +102,7 @@ struct ONNXReductionOpLowering : public ConversionPattern {
   ONNXReductionOpLowering(MLIRContext *ctx)
       : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
 
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     /*
      * Condition: reduction function must be associative and commutative.
@@ -1,4 +1,5 @@
-//===--------------- Conv.cpp - Lowering Convolution Op --------------------===//
+//===--------------- Conv.cpp - Lowering Convolution Op
+//--------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -175,14 +176,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
 
         // Emit the bias, if needed.
         if (hasBias) {
-          auto loadResult =
-              rewriter.create<LoadOp>(loc, alloc, resultIndices);
+          auto loadResult = rewriter.create<LoadOp>(loc, alloc, resultIndices);
           SmallVector<Value, 4> biasIndices;
           biasIndices.emplace_back(kernel);
-          auto loadBias =
-              rewriter.create<LoadOp>(loc, biasOperand, kernel);
-          auto resultWithBias =
-              rewriter.create<MulFOp>(loc, loadResult, loadBias);
+          auto loadBias = rewriter.create<LoadOp>(loc, biasOperand, kernel);
+          auto resultWithBias = rewriter.create<MulFOp>(
+              loc, loadResult, loadBias);
           // Store initializer value into output location.
           rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
         }
@@ -459,7 +459,8 @@ struct ONNXPoolOpLowering : public ConversionPattern {
       poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
 
       // poolStartMap and poolEndMap
-      poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
+      poolStartMap =
+          AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
       poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
     }
 
@@ -36,9 +36,7 @@ MemRefType convertToMemRefType(Type type) {
 
 /// Insert an allocation and deallocation for the given MemRefType.
 Value insertAllocAndDealloc(MemRefType type, Location loc,
-    PatternRewriter &rewriter,
-    bool insertDealloc,
-    ArrayRef<Value> operands) {
+    PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands) {
   // Put together alloc operands for any dynamic dimensions of the memref.
   AllocOp alloc;
   if (!operands.empty()) {
@@ -64,10 +62,10 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
         auto operandDim =
            rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
         if (maxDim) {
-          auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
-              operandDim, maxDim);
-          maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
-              maxDim);
+          auto maxCondition = rewriter.create<CmpIOp>(
+              loc, CmpIPredicate::sgt, operandDim, maxDim);
+          maxDim = rewriter.create<SelectOp>(
+              loc, maxCondition, operandDim, maxDim);
         } else {
           maxDim = operandDim;
         }
@@ -122,8 +120,8 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
 // Create a mapping from result type's dimensions to input type's dimensions,
 // given that the result type is the result of a reduction op over the input
 // type.
-std::map<int64_t, int64_t>
-getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
+std::map<int64_t, int64_t> getReductionMapping(
+    MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
   std::map<int64_t, int64_t> OutInDimMap;
   int64_t rank = inputTy.getRank();
 
@@ -152,9 +150,8 @@ getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
 
 // Add bounds associated with the op operand to the KRNL iteration pack.
 // Dynamic dimenions are supported.
-void addDimensionToPack(ConversionPatternRewriter &rewriter,
-    Location loc, KrnlIterateOperandPack &pack,
-    Value operand, int index) {
+void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
+    KrnlIterateOperandPack &pack, Value operand, int index) {
   auto shape = operand.getType().cast<MemRefType>().getShape();
   if (shape[index] < 0) {
     pack.pushConstantBound(0);
@@ -168,10 +165,9 @@ void addDimensionToPack(ConversionPatternRewriter &rewriter,
 
 // Function that defines the KRNL dialect loops and their respective
 // optimized version.
-KrnlOptimizeLoopsOp
-emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
-    std::vector<Value> &loops,
-    std::vector<Value> &optimizedLoops, int64_t numLoops) {
+KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
+    Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
+    int64_t numLoops) {
   // Define loops.
   auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
   loops.reserve(numLoops);
@@ -190,8 +186,7 @@ emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
 // Function that emits the loops and their optimized version.
 // The function returns a reference to the inner optimization block.
 Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
-    std::vector<Value> &loops,
-    std::vector<Value> &optimizedLoops,
+    std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
     int64_t numLoops) {
   KrnlOptimizeLoopsOp optimizedLoopsOp =
       emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
@@ -201,10 +196,9 @@ Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
 // Function which emits a basic set of loops and optimized loops
 // for a given operation argument. A reference to the loop optimization
 // block is returned in the last argument of the function.
-void emitKrnlLoopsAndIterationForOperand(
-    ConversionPatternRewriter &rewriter, Location loc, Value operand,
-    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
-    KrnlIterateOp &iterateOp) {
+void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
+    Location loc, Value operand, std::vector<Value> &originalLoops,
+    KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
   // Operand shape.
   auto shape = operand.getType().cast<MemRefType>().getShape();
 
@@ -240,9 +234,9 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
 
 // Get run-time dimension information for unknown dimensions used for
 // broadcasting.
-std::map<int, std::map<int, Value>>
-getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
-    MemRefType memRefType, ArrayRef<Value> operands) {
+std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
+    ConversionPatternRewriter &rewriter, MemRefType memRefType,
+    ArrayRef<Value> operands) {
   auto memRefShape = memRefType.getShape();
   int64_t rank = memRefShape.size();
   // For unknown dimensions, we need to get dimension values at runtime in
@@ -286,9 +280,8 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
 
 // Extract induction variables that are used for broadcasting values of a
 // given operand.
-std::vector<Value>
-getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
-    ArrayRef<Value> loopIVs, Value operand,
-    std::map<int, Value> broadcastedDims) {
+std::vector<Value> getLoopIVsForBroadcasting(Location loc,
+    ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
+    std::map<int, Value> broadcastedDims) {
   // `operand` must has a ranked type. This should have been checked by the
   // shape inference pass.
@@ -310,8 +303,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
       // If its value is 1, it is broadcasted dimension.
       // Otherwise, non-broadcasted dimension.
       auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
-          loopIVs[loopIdx]);
+      auto idx = rewriter.create<SelectOp>(
+          loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]);
       newLoopIVs.insert(newLoopIVs.begin(), idx);
     } else {
       // Non-broadcasted dimension
@@ -38,8 +38,7 @@ struct ONNXConstantOpLowering : public ConversionPattern {
       numElements *= shape[i];
 
     // Emit the constant global in Krnl dialect.
-    auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
-        memRefType,
+    auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
         rewriter.getI64ArrayAttr(shape),
         rewriter.getStringAttr("constant_" + std::to_string(constantID)),
         constantOp.value().getValue());
@@ -39,7 +39,6 @@ MLONNXOpsDialect::MLONNXOpsDialect(mlir::MLIRContext *ctx)
       >();
 }
 
-
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
@@ -19,8 +19,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 
-#include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
+#include "src/Interface/ShapeInferenceInterface.hpp"
 
 namespace mlir {
 
@@ -19,8 +19,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 
-#include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
+#include "src/Interface/ShapeInferenceInterface.hpp"
 
 namespace mlir {
 
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
 
 using namespace mlir;
@@ -20,4 +19,3 @@ using namespace mlir;
 //===----------------------------------------------------------------------===//
 
-
 #include "src/Interface/PromotableConstOperandsOpInterface.cpp.inc"
@@ -46,10 +46,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
 void EmitLLVMBitCode(
     const mlir::OwningModuleRef &module, string outputFilename) {
   error_code error;
-  llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error,
-      llvm::sys::fs::F_None);
-  llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
-      moduleBitcodeStream);
+  llvm::raw_fd_ostream moduleBitcodeStream(
+      outputFilename, error, llvm::sys::fs::F_None);
+  llvm::WriteBitcodeToFile(
+      *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
   moduleBitcodeStream.flush();
 }
 
@@ -99,7 +99,6 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
   assert(inputIsONNX != inputIsMLIR &&
          "Either ONNX model or MLIR file needs to be provided.");
 
-
   if (inputIsONNX) {
     ImportFrontendModelFile(inputFilename, context, module);
   } else {
@@ -28,9 +28,9 @@
 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
 #include "mlir/ExecutionEngine/OptUtils.h"
-#include "mlir/InitAllDialects.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
+#include "mlir/InitAllDialects.h"
 #include "mlir/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -66,8 +66,7 @@ void processInputFile(std::string inputFilename,
     mlir::OwningModuleRef &module);
 
 void outputCode(
-    mlir::OwningModuleRef &module, std::string filename,
-    std::string extension);
+    mlir::OwningModuleRef &module, std::string filename, std::string extension);
 
 void emitOutputFiles(std::string outputBaseName,
     EmissionTargetType emissionTarget, mlir::MLIRContext &context,
@@ -33,8 +33,8 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
   return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
 }
 
-void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
-    DynMemRef *tensor) {
+void setDynMemRef(
+    OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
   if (tensorDict->orderedNames.size() <= idx)
     tensorDict->orderedNames.resize(idx + 1);
 
@@ -39,7 +39,6 @@ typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
 extern "C" {
 #endif
 
-
 // Get number of dynamic memrefs in OrderedDynMemRefDict dict.
 int numDynMemRefs(OrderedDynMemRefDict *dict);
 
@@ -53,8 +52,8 @@ DynMemRef *createDynMemRef(int rank);
 DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
 
 // Set the i-th dynmemref in orderedDict to be dynMemRef.
-void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
-    DynMemRef *dynMemRef);
+void setDynMemRef(
+    OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
 
 // Get data pointer from dynMemRef.
 void *getData(DynMemRef *dynMemRef);
@@ -1,14 +1,14 @@
 #include "Runtime.hpp"
 
-ExecutionSession::ExecutionSession(std::string sharedLibPath,
-    std::string entryPointName) {
+ExecutionSession::ExecutionSession(
+    std::string sharedLibPath, std::string entryPointName) {
   _sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY);
   _entryPointFunc =
       (entryPointFuncType)dlsym(_sharedLibraryHandle, entryPointName.c_str());
 }
 
-std::vector<py::array>
-ExecutionSession::run(std::vector<py::array> inputsPyArray) {
+std::vector<py::array> ExecutionSession::run(
+    std::vector<py::array> inputsPyArray) {
   assert(_entryPointFunc && "entry point not loaded");
   auto *wrappedInput = createOrderedDynMemRefDict();
   int inputIdx = 0;
@@ -40,8 +40,8 @@ ExecutionSession::run(std::vector<py::array> inputsPyArray) {
   auto *wrappedOutput = _entryPointFunc(wrappedInput);
   for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
     auto *dynMemRef = getDynMemRef(wrappedOutput, i);
-    auto shape = std::vector<int64_t>(dynMemRef->sizes,
-        dynMemRef->sizes + dynMemRef->rank);
+    auto shape = std::vector<int64_t>(
+        dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
     outputPyArrays.emplace_back(
         py::array(py::dtype("float32"), shape, dynMemRef->data));
   }
@@ -144,9 +144,9 @@ public:
 
       assert(krnlGlobalOp.value().hasValue() &&
              "Krnl Global must always have a value");
-      global = rewriter.create<LLVM::GlobalOp>(loc,
-          llvmGlobalType, /*isConstant=*/true,
-          LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
+      global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
+          /*isConstant=*/true, LLVM::Linkage::Internal, name,
+          krnlGlobalOp.value().getValue());
     }
 
     // Some frequently used types.
@@ -12,8 +12,8 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 
-#include <numeric>
 #include "src/Dialect/ONNX/ONNXOps.hpp"
+#include <numeric>
 
 using namespace mlir;
 
@@ -27,7 +27,8 @@ namespace {
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Transform/ONNX/ONNXDecompose.inc"
 
-struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
+struct DecomposeONNXToONNXPass
+    : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
   void runOnFunction() final;
 };
 } // end anonymous namespace.
@@ -9,10 +9,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/StandardTypes.h"
 
 #include "src/Interface/ShapeInferenceInterface.hpp"
 #include "src/Pass/Passes.hpp"
@@ -25,7 +25,8 @@ namespace {
  * candidate operations and propagating the shape information until the list
  * of operations is empty [credit MLIR authors].
  */
-class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
+class ShapeInferencePass
+    : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
 public:
   void runOnFunction() override {
     auto f = getFunction();
@@ -63,8 +64,7 @@ public:
 
     if (auto terminator_op = f.getBody().back().getTerminator()) {
       auto results = terminator_op->getOperandTypes();
-      f.setType(FunctionType::get(
-          f.getType().getInputs(),
+      f.setType(FunctionType::get(f.getType().getInputs(),
           std::vector<Type>(results.begin(), results.end()), f.getContext()));
     }
   }
@@ -146,5 +146,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
   return std::make_unique<ShapeInferencePass>();
 }
 
-static PassRegistration<ShapeInferencePass>
-    pass("shape-inference", "Shape inference for frontend dialects.");
+static PassRegistration<ShapeInferencePass> pass(
+    "shape-inference", "Shape inference for frontend dialects.");
src/main.cpp (20 changed lines)
@@ -14,10 +14,10 @@ using namespace onnx_mlir;
 int main(int argc, char *argv[]) {
   registerDialects();
 
-  llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options",
-      "These are frontend options.");
-  llvm::cl::opt<string> inputFilename(
-      llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
+  llvm::cl::OptionCategory OnnxMlirOptions(
+      "ONNX MLIR Options", "These are frontend options.");
+  llvm::cl::opt<string> inputFilename(llvm::cl::Positional,
+      llvm::cl::desc("<input file>"), llvm::cl::init("-"),
       llvm::cl::cat(OnnxMlirOptions));
 
   llvm::cl::opt<EmissionTargetType> emissionTarget(
@@ -26,18 +26,18 @@ int main(int argc, char *argv[]) {
       clEnumVal(EmitONNXBasic,
           "Ingest ONNX and emit the basic ONNX operations without"
           "inferred shapes."),
-      clEnumVal(EmitONNXIR,
-          "Ingest ONNX and emit corresponding ONNX dialect."),
-      clEnumVal(EmitMLIR,
-          "Lower model to MLIR built-in transformation dialect."),
+      clEnumVal(
+          EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."),
+      clEnumVal(
+          EmitMLIR, "Lower model to MLIR built-in transformation dialect."),
       clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
       clEnumVal(EmitLLVMBC, "Lower model to LLVM IR and emit (to file) "
                             "LLVM bitcode for model.")),
       llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnxMlirOptions));
 
   llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
-  llvm::cl::ParseCommandLineOptions(argc, argv,
-      "ONNX MLIR modular optimizer driver\n");
+  llvm::cl::ParseCommandLineOptions(
+      argc, argv, "ONNX MLIR modular optimizer driver\n");
 
   mlir::MLIRContext context;
   mlir::OwningModuleRef module;