Support attribute promotion.
This commit is contained in:
parent
d86591d61a
commit
955968b750
|
@ -12,7 +12,10 @@ add_library(compiler
|
||||||
pass/onnx_combine.cpp
|
pass/onnx_combine.cpp
|
||||||
pass/onnx_rewrite.cpp
|
pass/onnx_rewrite.cpp
|
||||||
pass/onnx_decompose.cpp
|
pass/onnx_decompose.cpp
|
||||||
pass/passes.hpp)
|
pass/passes.hpp
|
||||||
|
dialect/onnx/const_operands_interface.hpp
|
||||||
|
dialect/onnx/attribute_promotion.cpp
|
||||||
|
)
|
||||||
|
|
||||||
# Include root src directory.
|
# Include root src directory.
|
||||||
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
|
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
|
||||||
|
@ -47,12 +50,18 @@ onnf_tablegen(onnx_rewrite.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(gen_onnx_rewrite)
|
add_public_tablegen_target(gen_onnx_rewrite)
|
||||||
add_dependencies(compiler gen_onnx_rewrite)
|
add_dependencies(compiler gen_onnx_rewrite)
|
||||||
|
|
||||||
|
add_subdirectory(dialect/onnx)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
|
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
|
||||||
add_public_tablegen_target(gen_onnx)
|
add_public_tablegen_target(gen_onnx)
|
||||||
|
|
||||||
|
add_dependencies(gen_onnx gen_shape_inference)
|
||||||
|
add_dependencies(gen_onnx gen_const_operands_interface)
|
||||||
add_dependencies(compiler gen_onnx)
|
add_dependencies(compiler gen_onnx)
|
||||||
|
|
||||||
add_onnf_dialect_doc(onnx dialect/onnx/onnx.td)
|
add_onnf_dialect_doc(onnx dialect/onnx/onnx.td)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
||||||
|
@ -66,7 +75,7 @@ target_include_directories(onnf_onnx_decompose
|
||||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||||
${ONNF_SRC_ROOT})
|
${ONNF_SRC_ROOT})
|
||||||
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
|
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
|
||||||
add_dependencies(onnf_onnx_decompose gen_krnl_ops)
|
add_dependencies(onnf_onnx_decompose gen_onnx)
|
||||||
|
|
||||||
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
|
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
|
||||||
target_include_directories(onnf_shape_inference
|
target_include_directories(onnf_shape_inference
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
#add_library(onnf_fold_const_operand_to_attribute fold_const_operand_to_attribute.hpp)
|
||||||
|
#target_include_directories(onnf_fold_const_operand_to_attribute
|
||||||
|
# PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT})
|
||||||
|
#target_link_libraries(onnf_fold_const_operand_to_attribute ${MLIRLibs})
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS const_operands_interface.td)
|
||||||
|
onnf_tablegen(const_operands_interface.hpp.inc -gen-op-interface-decls)
|
||||||
|
onnf_tablegen(const_operands_interface.cpp.inc -gen-op-interface-defs)
|
||||||
|
message("TBLGENOUT" ${TABLEGEN_OUTPUT})
|
||||||
|
add_public_tablegen_target(gen_const_operands_interface)
|
||||||
|
add_dependencies(compiler gen_const_operands_interface)
|
|
@ -0,0 +1,79 @@
|
||||||
|
//===----- shape_inference_pass.cpp - Shape Inference ---------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements a Function level pass performing propagation of array
|
||||||
|
// shapes through function specialization.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include "src/dialect/onnx/const_operands_interface.hpp"
|
||||||
|
#include "src/dialect/onnx/onnx_ops.hpp"
|
||||||
|
#include "src/pass/passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
#include "src/dialect/onnx/const_operands_interface.cpp.inc"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/*!
|
||||||
|
* FunctionPass that performs shape inference by iterating over a list of
|
||||||
|
* candidate operations and propagating the shape information until the list
|
||||||
|
* of operations is empty [credit MLIR authors].
|
||||||
|
*/
|
||||||
|
class AttributePromotionPass
|
||||||
|
: public mlir::FunctionPass<AttributePromotionPass> {
|
||||||
|
public:
|
||||||
|
void runOnFunction() override {
|
||||||
|
auto f = getFunction();
|
||||||
|
|
||||||
|
llvm::Optional<mlir::Value> none;
|
||||||
|
f.walk([&](mlir::Operation *op) {
|
||||||
|
if (op->getNumResults() && op->getOpResult(0).getType().isa<NoneType>())
|
||||||
|
none = op->getOpResult(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!none) {
|
||||||
|
OpBuilder builder(f.getContext());
|
||||||
|
builder.setInsertionPointToStart(&f.front());
|
||||||
|
none =
|
||||||
|
builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate on the operations that need shape inference i.e the operations
|
||||||
|
// that return a dynamic shape.
|
||||||
|
f.walk([&](mlir::Operation *op) {
|
||||||
|
if (IdentifyConstOperandsOpInterface opWithConstOperand =
|
||||||
|
dyn_cast<IdentifyConstOperandsOpInterface>(op)) {
|
||||||
|
auto promotableOperands = opWithConstOperand.identifyPromotableConstOperands();
|
||||||
|
for (const auto& operandNameToIdx : promotableOperands) {
|
||||||
|
auto name = operandNameToIdx.first;
|
||||||
|
auto idx = operandNameToIdx.second;
|
||||||
|
|
||||||
|
auto operandToPromote = op->getOperand(idx);
|
||||||
|
if (auto constantOp =
|
||||||
|
dyn_cast_or_null<ConstantOp>(operandToPromote.getDefiningOp())) {
|
||||||
|
op->setAttr(name, constantOp.value());
|
||||||
|
op->setOperand(idx, *none);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Create a Shape Inference pass.
|
||||||
|
*/
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() {
|
||||||
|
return std::make_unique<AttributePromotionPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<AttributePromotionPass> pass(
|
||||||
|
"attribute-promotion", "Shape inference for frontend dialects.");
|
|
@ -0,0 +1,24 @@
|
||||||
|
//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains the declarations of the shape inference interfaces defined
|
||||||
|
// in ShapeInferenceInterface.td.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
/// Include the auto-generated declarations.
|
||||||
|
#include "src/dialect/onnx/const_operands_interface.hpp.inc"
|
||||||
|
|
||||||
|
} // end namespace mlir
|
|
@ -0,0 +1,32 @@
|
||||||
|
//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Defines the operations of the Shape Inference Op Interface.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
#else
|
||||||
|
#define IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
|
||||||
|
#ifdef OP_BASE
|
||||||
|
#else
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def IdentifyConstOperandsOpInterface : OpInterface<"IdentifyConstOperandsOpInterface"> {
|
||||||
|
let description = [{
|
||||||
|
Interface to access a registered method to infer the return types for an
|
||||||
|
operation that can be used during type inference.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||||
|
"std::map<std::string, size_t>", "identifyPromotableConstOperands">
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
@ -22,6 +22,11 @@ include "mlir/IR/OpBase.td"
|
||||||
include "pass/shape_inference_interface.td"
|
include "pass/shape_inference_interface.td"
|
||||||
#endif // SHAPE_INFERENCE_INTERFACE
|
#endif // SHAPE_INFERENCE_INTERFACE
|
||||||
|
|
||||||
|
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
#else
|
||||||
|
include "dialect/onnx/const_operands_interface.td"
|
||||||
|
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
|
||||||
def ONNX_Dialect : Dialect {
|
def ONNX_Dialect : Dialect {
|
||||||
let name = "onnx";
|
let name = "onnx";
|
||||||
let cppNamespace = "";
|
let cppNamespace = "";
|
||||||
|
|
|
@ -10,6 +10,9 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
@ -17,6 +20,7 @@
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
#include "src/pass/shape_inference_interface.hpp"
|
#include "src/pass/shape_inference_interface.hpp"
|
||||||
|
#include "src/dialect/onnx/const_operands_interface.hpp"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
|
|
@ -2501,7 +2501,9 @@ def ONNXReluOp:ONNX_Op<"Relu",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
|
||||||
|
OpInterface<"IdentifyConstOperandsOpInterface">]> {
|
||||||
let summary = "ONNX Reshape operation";
|
let summary = "ONNX Reshape operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Reshape the input tensor similar to numpy.reshape."
|
"Reshape the input tensor similar to numpy.reshape."
|
||||||
|
@ -2512,8 +2514,11 @@ def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||||
"from the input tensor)."
|
"from the input tensor)."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape);
|
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
std::map<std::string, size_t> identifyPromotableConstOperands() { return {{"shape", 1}}; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXResizeOp:ONNX_Op<"Resize",
|
def ONNXResizeOp:ONNX_Op<"Resize",
|
||||||
|
|
|
@ -126,6 +126,7 @@ int main(int argc, char *argv[]) {
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
pm.addPass(mlir::createAttributePromotionPass());
|
||||||
|
|
||||||
if (emissionTarget >= EmitMLIR) {
|
if (emissionTarget >= EmitMLIR) {
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
|
|
@ -20,6 +20,8 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
|
||||||
|
|
||||||
std::unique_ptr<Pass> createShapeInferencePass();
|
std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createAttributePromotionPass();
|
||||||
|
|
||||||
/// Add pass for lowering to Krnl IR.
|
/// Add pass for lowering to Krnl IR.
|
||||||
std::unique_ptr<Pass> createLowerToKrnlPass();
|
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
#include "shape_inference_interface.hpp"
|
#include "shape_inference_interface.hpp"
|
||||||
#include "src/dialect/onnx/onnx_ops.hpp"
|
|
||||||
|
|
||||||
#include "passes.hpp"
|
#include "passes.hpp"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue