Revert "Support attribute promotion."

This reverts commit 955968b750.
This commit is contained in:
Tian Jin 2020-03-17 17:41:59 +08:00
parent 955968b750
commit c25831094e
11 changed files with 6 additions and 178 deletions

View File

@ -12,10 +12,7 @@ add_library(compiler
pass/onnx_combine.cpp pass/onnx_combine.cpp
pass/onnx_rewrite.cpp pass/onnx_rewrite.cpp
pass/onnx_decompose.cpp pass/onnx_decompose.cpp
pass/passes.hpp pass/passes.hpp)
dialect/onnx/const_operands_interface.hpp
dialect/onnx/attribute_promotion.cpp
)
# Include root src directory. # Include root src directory.
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT}) target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
@ -50,18 +47,12 @@ onnf_tablegen(onnx_rewrite.inc -gen-rewriters)
add_public_tablegen_target(gen_onnx_rewrite) add_public_tablegen_target(gen_onnx_rewrite)
add_dependencies(compiler gen_onnx_rewrite) add_dependencies(compiler gen_onnx_rewrite)
add_subdirectory(dialect/onnx)
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md) set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
add_public_tablegen_target(gen_onnx) add_public_tablegen_target(gen_onnx)
add_dependencies(gen_onnx gen_shape_inference)
add_dependencies(gen_onnx gen_const_operands_interface)
add_dependencies(compiler gen_onnx) add_dependencies(compiler gen_onnx)
add_onnf_dialect_doc(onnx dialect/onnx/onnx.td) add_onnf_dialect_doc(onnx dialect/onnx/onnx.td)
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td) set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
@ -75,7 +66,7 @@ target_include_directories(onnf_onnx_decompose
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT}) ${ONNF_SRC_ROOT})
target_link_libraries(onnf_onnx_decompose ${MLIRLibs}) target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
add_dependencies(onnf_onnx_decompose gen_onnx) add_dependencies(onnf_onnx_decompose gen_krnl_ops)
add_library(onnf_shape_inference pass/shape_inference_pass.cpp) add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
target_include_directories(onnf_shape_inference target_include_directories(onnf_shape_inference

View File

@ -1,11 +0,0 @@
#add_library(onnf_fold_const_operand_to_attribute fold_const_operand_to_attribute.hpp)
#target_include_directories(onnf_fold_const_operand_to_attribute
# PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT})
#target_link_libraries(onnf_fold_const_operand_to_attribute ${MLIRLibs})
set(LLVM_TARGET_DEFINITIONS const_operands_interface.td)
onnf_tablegen(const_operands_interface.hpp.inc -gen-op-interface-decls)
onnf_tablegen(const_operands_interface.cpp.inc -gen-op-interface-defs)
message("TBLGENOUT" ${TABLEGEN_OUTPUT})
add_public_tablegen_target(gen_const_operands_interface)
add_dependencies(compiler gen_const_operands_interface)

View File

@ -1,79 +0,0 @@
//===----- shape_inference_pass.cpp - Shape Inference ---------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a Function level pass performing propagation of array
// shapes through function specialization.
//
//===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "src/dialect/onnx/const_operands_interface.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
using namespace mlir;
#include "src/dialect/onnx/const_operands_interface.cpp.inc"
namespace {
/*!
* FunctionPass that performs shape inference by iterating over a list of
* candidate operations and propagating the shape information until the list
* of operations is empty [credit MLIR authors].
*/
class AttributePromotionPass
: public mlir::FunctionPass<AttributePromotionPass> {
public:
void runOnFunction() override {
auto f = getFunction();
llvm::Optional<mlir::Value> none;
f.walk([&](mlir::Operation *op) {
if (op->getNumResults() && op->getOpResult(0).getType().isa<NoneType>())
none = op->getOpResult(0);
});
if (!none) {
OpBuilder builder(f.getContext());
builder.setInsertionPointToStart(&f.front());
none =
builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr());
}
// Iterate on the operations that need shape inference i.e the operations
// that return a dynamic shape.
f.walk([&](mlir::Operation *op) {
if (IdentifyConstOperandsOpInterface opWithConstOperand =
dyn_cast<IdentifyConstOperandsOpInterface>(op)) {
auto promotableOperands = opWithConstOperand.identifyPromotableConstOperands();
for (const auto& operandNameToIdx : promotableOperands) {
auto name = operandNameToIdx.first;
auto idx = operandNameToIdx.second;
auto operandToPromote = op->getOperand(idx);
if (auto constantOp =
dyn_cast_or_null<ConstantOp>(operandToPromote.getDefiningOp())) {
op->setAttr(name, constantOp.value());
op->setOperand(idx, *none);
}
}
}
});
}
};
} // end anonymous namespace
/*!
* Create a Shape Inference pass.
*/
std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() {
return std::make_unique<AttributePromotionPass>();
}
static PassRegistration<AttributePromotionPass> pass(
"attribute-promotion", "Shape inference for frontend dialects.");

View File

@ -1,24 +0,0 @@
//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains the declarations of the shape inference interfaces defined
// in ShapeInferenceInterface.td.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <map>
#include <string>
#include "mlir/IR/OpDefinition.h"
namespace mlir {
/// Include the auto-generated declarations.
#include "src/dialect/onnx/const_operands_interface.hpp.inc"
} // end namespace mlir

View File

@ -1,32 +0,0 @@
//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Defines the operations of the Shape Inference Op Interface.
//
//===----------------------------------------------------------------------===//
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
#else
#define IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def IdentifyConstOperandsOpInterface : OpInterface<"IdentifyConstOperandsOpInterface"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that can be used during type inference.
}];
let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"std::map<std::string, size_t>", "identifyPromotableConstOperands">
];
}
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE

View File

@ -22,11 +22,6 @@ include "mlir/IR/OpBase.td"
include "pass/shape_inference_interface.td" include "pass/shape_inference_interface.td"
#endif // SHAPE_INFERENCE_INTERFACE #endif // SHAPE_INFERENCE_INTERFACE
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
#else
include "dialect/onnx/const_operands_interface.td"
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
def ONNX_Dialect : Dialect { def ONNX_Dialect : Dialect {
let name = "onnx"; let name = "onnx";
let cppNamespace = ""; let cppNamespace = "";

View File

@ -10,9 +10,6 @@
#pragma once #pragma once
#include <map>
#include <string>
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
@ -20,7 +17,6 @@
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "src/pass/shape_inference_interface.hpp" #include "src/pass/shape_inference_interface.hpp"
#include "src/dialect/onnx/const_operands_interface.hpp"
namespace mlir { namespace mlir {

View File

@ -2501,9 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
} }
def ONNXReshapeOp:ONNX_Op<"Reshape", def ONNXReshapeOp:ONNX_Op<"Reshape",
[NoSideEffect, [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
OpInterface<"IdentifyConstOperandsOpInterface">]> {
let summary = "ONNX Reshape operation"; let summary = "ONNX Reshape operation";
let description = [{ let description = [{
"Reshape the input tensor similar to numpy.reshape." "Reshape the input tensor similar to numpy.reshape."
@ -2514,11 +2512,8 @@ def ONNXReshapeOp:ONNX_Op<"Reshape",
"from the input tensor)." "from the input tensor)."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape); AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
let extraClassDeclaration = [{
std::map<std::string, size_t> identifyPromotableConstOperands() { return {{"shape", 1}}; }
}];
} }
def ONNXResizeOp:ONNX_Op<"Resize", def ONNXResizeOp:ONNX_Op<"Resize",

View File

@ -126,7 +126,6 @@ int main(int argc, char *argv[]) {
pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createAttributePromotionPass());
if (emissionTarget >= EmitMLIR) { if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass()); pm.addPass(mlir::createLowerToKrnlPass());

View File

@ -20,8 +20,6 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
std::unique_ptr<Pass> createShapeInferencePass(); std::unique_ptr<Pass> createShapeInferencePass();
std::unique_ptr<Pass> createAttributePromotionPass();
/// Add pass for lowering to Krnl IR. /// Add pass for lowering to Krnl IR.
std::unique_ptr<Pass> createLowerToKrnlPass(); std::unique_ptr<Pass> createLowerToKrnlPass();

View File

@ -12,9 +12,9 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/IR/StandardTypes.h"
#include "shape_inference_interface.hpp" #include "shape_inference_interface.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "passes.hpp" #include "passes.hpp"