Support attribute promotion.

This commit is contained in:
Tian Jin 2020-03-17 17:39:34 +08:00
parent d86591d61a
commit 955968b750
11 changed files with 178 additions and 6 deletions

View File

@ -12,7 +12,10 @@ add_library(compiler
pass/onnx_combine.cpp
pass/onnx_rewrite.cpp
pass/onnx_decompose.cpp
pass/passes.hpp)
pass/passes.hpp
dialect/onnx/const_operands_interface.hpp
dialect/onnx/attribute_promotion.cpp
)
# Include root src directory.
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
@ -47,12 +50,18 @@ onnf_tablegen(onnx_rewrite.inc -gen-rewriters)
add_public_tablegen_target(gen_onnx_rewrite)
add_dependencies(compiler gen_onnx_rewrite)
add_subdirectory(dialect/onnx)
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
add_public_tablegen_target(gen_onnx)
add_dependencies(gen_onnx gen_shape_inference)
add_dependencies(gen_onnx gen_const_operands_interface)
add_dependencies(compiler gen_onnx)
add_onnf_dialect_doc(onnx dialect/onnx/onnx.td)
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
@ -66,7 +75,7 @@ target_include_directories(onnf_onnx_decompose
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
add_dependencies(onnf_onnx_decompose gen_krnl_ops)
add_dependencies(onnf_onnx_decompose gen_onnx)
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
target_include_directories(onnf_shape_inference

View File

@ -0,0 +1,11 @@
#add_library(onnf_fold_const_operand_to_attribute fold_const_operand_to_attribute.hpp)
#target_include_directories(onnf_fold_const_operand_to_attribute
# PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT})
#target_link_libraries(onnf_fold_const_operand_to_attribute ${MLIRLibs})
set(LLVM_TARGET_DEFINITIONS const_operands_interface.td)
onnf_tablegen(const_operands_interface.hpp.inc -gen-op-interface-decls)
onnf_tablegen(const_operands_interface.cpp.inc -gen-op-interface-defs)
message("TBLGENOUT" ${TABLEGEN_OUTPUT})
add_public_tablegen_target(gen_const_operands_interface)
add_dependencies(compiler gen_const_operands_interface)

View File

@ -0,0 +1,79 @@
//===----- shape_inference_pass.cpp - Shape Inference ---------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a Function level pass performing propagation of array
// shapes through function specialization.
//
//===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "src/dialect/onnx/const_operands_interface.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
using namespace mlir;
#include "src/dialect/onnx/const_operands_interface.cpp.inc"
namespace {
/*!
* FunctionPass that performs shape inference by iterating over a list of
* candidate operations and propagating the shape information until the list
* of operations is empty [credit MLIR authors].
*/
class AttributePromotionPass
: public mlir::FunctionPass<AttributePromotionPass> {
public:
void runOnFunction() override {
auto f = getFunction();
llvm::Optional<mlir::Value> none;
f.walk([&](mlir::Operation *op) {
if (op->getNumResults() && op->getOpResult(0).getType().isa<NoneType>())
none = op->getOpResult(0);
});
if (!none) {
OpBuilder builder(f.getContext());
builder.setInsertionPointToStart(&f.front());
none =
builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr());
}
// Iterate on the operations that need shape inference i.e the operations
// that return a dynamic shape.
f.walk([&](mlir::Operation *op) {
if (IdentifyConstOperandsOpInterface opWithConstOperand =
dyn_cast<IdentifyConstOperandsOpInterface>(op)) {
auto promotableOperands = opWithConstOperand.identifyPromotableConstOperands();
for (const auto& operandNameToIdx : promotableOperands) {
auto name = operandNameToIdx.first;
auto idx = operandNameToIdx.second;
auto operandToPromote = op->getOperand(idx);
if (auto constantOp =
dyn_cast_or_null<ConstantOp>(operandToPromote.getDefiningOp())) {
op->setAttr(name, constantOp.value());
op->setOperand(idx, *none);
}
}
}
});
}
};
} // end anonymous namespace
/*!
* Create a Shape Inference pass.
*/
std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() {
return std::make_unique<AttributePromotionPass>();
}
static PassRegistration<AttributePromotionPass> pass(
"attribute-promotion", "Shape inference for frontend dialects.");

View File

@ -0,0 +1,24 @@
//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains the declarations of the shape inference interfaces defined
// in ShapeInferenceInterface.td.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <map>
#include <string>
#include "mlir/IR/OpDefinition.h"
namespace mlir {
/// Include the auto-generated declarations.
#include "src/dialect/onnx/const_operands_interface.hpp.inc"
} // end namespace mlir

View File

@ -0,0 +1,32 @@
//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Defines the operations of the Shape Inference Op Interface.
//
//===----------------------------------------------------------------------===//
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
#else
#define IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def IdentifyConstOperandsOpInterface : OpInterface<"IdentifyConstOperandsOpInterface"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that can be used during type inference.
}];
let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"std::map<std::string, size_t>", "identifyPromotableConstOperands">
];
}
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE

View File

@ -22,6 +22,11 @@ include "mlir/IR/OpBase.td"
include "pass/shape_inference_interface.td"
#endif // SHAPE_INFERENCE_INTERFACE
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
#else
include "dialect/onnx/const_operands_interface.td"
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
def ONNX_Dialect : Dialect {
let name = "onnx";
let cppNamespace = "";

View File

@ -10,6 +10,9 @@
#pragma once
#include <map>
#include <string>
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
@ -17,6 +20,7 @@
#include "mlir/IR/StandardTypes.h"
#include "src/pass/shape_inference_interface.hpp"
#include "src/dialect/onnx/const_operands_interface.hpp"
namespace mlir {

View File

@ -26,7 +26,7 @@ def ONNXAbsOp:ONNX_Op<"Abs",
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
];
}
def ONNXAcosOp:ONNX_Op<"Acos",
@ -2501,7 +2501,9 @@ def ONNXReluOp:ONNX_Op<"Relu",
}
def ONNXReshapeOp:ONNX_Op<"Reshape",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
[NoSideEffect,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
OpInterface<"IdentifyConstOperandsOpInterface">]> {
let summary = "ONNX Reshape operation";
let description = [{
"Reshape the input tensor similar to numpy.reshape."
@ -2512,8 +2514,11 @@ def ONNXReshapeOp:ONNX_Op<"Reshape",
"from the input tensor)."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
let extraClassDeclaration = [{
std::map<std::string, size_t> identifyPromotableConstOperands() { return {{"shape", 1}}; }
}];
}
def ONNXResizeOp:ONNX_Op<"Resize",

View File

@ -126,6 +126,7 @@ int main(int argc, char *argv[]) {
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createAttributePromotionPass());
if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass());

View File

@ -20,6 +20,8 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
std::unique_ptr<Pass> createShapeInferencePass();
std::unique_ptr<Pass> createAttributePromotionPass();
/// Add pass for lowering to Krnl IR.
std::unique_ptr<Pass> createLowerToKrnlPass();

View File

@ -12,9 +12,9 @@
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/StandardTypes.h"
#include "shape_inference_interface.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "passes.hpp"