Support attribute promotion.
This commit is contained in:
parent
d86591d61a
commit
955968b750
|
@ -12,7 +12,10 @@ add_library(compiler
|
|||
pass/onnx_combine.cpp
|
||||
pass/onnx_rewrite.cpp
|
||||
pass/onnx_decompose.cpp
|
||||
pass/passes.hpp)
|
||||
pass/passes.hpp
|
||||
dialect/onnx/const_operands_interface.hpp
|
||||
dialect/onnx/attribute_promotion.cpp
|
||||
)
|
||||
|
||||
# Include root src directory.
|
||||
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
|
||||
|
@ -47,12 +50,18 @@ onnf_tablegen(onnx_rewrite.inc -gen-rewriters)
|
|||
add_public_tablegen_target(gen_onnx_rewrite)
|
||||
add_dependencies(compiler gen_onnx_rewrite)
|
||||
|
||||
add_subdirectory(dialect/onnx)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
|
||||
add_public_tablegen_target(gen_onnx)
|
||||
|
||||
add_dependencies(gen_onnx gen_shape_inference)
|
||||
add_dependencies(gen_onnx gen_const_operands_interface)
|
||||
add_dependencies(compiler gen_onnx)
|
||||
|
||||
add_onnf_dialect_doc(onnx dialect/onnx/onnx.td)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
||||
|
@ -66,7 +75,7 @@ target_include_directories(onnf_onnx_decompose
|
|||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||
${ONNF_SRC_ROOT})
|
||||
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
|
||||
add_dependencies(onnf_onnx_decompose gen_krnl_ops)
|
||||
add_dependencies(onnf_onnx_decompose gen_onnx)
|
||||
|
||||
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
|
||||
target_include_directories(onnf_shape_inference
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
#add_library(onnf_fold_const_operand_to_attribute fold_const_operand_to_attribute.hpp)
|
||||
#target_include_directories(onnf_fold_const_operand_to_attribute
|
||||
# PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT})
|
||||
#target_link_libraries(onnf_fold_const_operand_to_attribute ${MLIRLibs})
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS const_operands_interface.td)
|
||||
onnf_tablegen(const_operands_interface.hpp.inc -gen-op-interface-decls)
|
||||
onnf_tablegen(const_operands_interface.cpp.inc -gen-op-interface-defs)
|
||||
message("TBLGENOUT" ${TABLEGEN_OUTPUT})
|
||||
add_public_tablegen_target(gen_const_operands_interface)
|
||||
add_dependencies(compiler gen_const_operands_interface)
|
|
@ -0,0 +1,79 @@
|
|||
//===----- shape_inference_pass.cpp - Shape Inference ---------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a Function level pass performing propagation of array
|
||||
// shapes through function specialization.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "src/dialect/onnx/const_operands_interface.hpp"
|
||||
#include "src/dialect/onnx/onnx_ops.hpp"
|
||||
#include "src/pass/passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#include "src/dialect/onnx/const_operands_interface.cpp.inc"
|
||||
|
||||
namespace {
|
||||
/*!
|
||||
* FunctionPass that performs shape inference by iterating over a list of
|
||||
* candidate operations and propagating the shape information until the list
|
||||
* of operations is empty [credit MLIR authors].
|
||||
*/
|
||||
class AttributePromotionPass
|
||||
: public mlir::FunctionPass<AttributePromotionPass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto f = getFunction();
|
||||
|
||||
llvm::Optional<mlir::Value> none;
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (op->getNumResults() && op->getOpResult(0).getType().isa<NoneType>())
|
||||
none = op->getOpResult(0);
|
||||
});
|
||||
|
||||
if (!none) {
|
||||
OpBuilder builder(f.getContext());
|
||||
builder.setInsertionPointToStart(&f.front());
|
||||
none =
|
||||
builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr());
|
||||
}
|
||||
|
||||
// Iterate on the operations that need shape inference i.e the operations
|
||||
// that return a dynamic shape.
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (IdentifyConstOperandsOpInterface opWithConstOperand =
|
||||
dyn_cast<IdentifyConstOperandsOpInterface>(op)) {
|
||||
auto promotableOperands = opWithConstOperand.identifyPromotableConstOperands();
|
||||
for (const auto& operandNameToIdx : promotableOperands) {
|
||||
auto name = operandNameToIdx.first;
|
||||
auto idx = operandNameToIdx.second;
|
||||
|
||||
auto operandToPromote = op->getOperand(idx);
|
||||
if (auto constantOp =
|
||||
dyn_cast_or_null<ConstantOp>(operandToPromote.getDefiningOp())) {
|
||||
op->setAttr(name, constantOp.value());
|
||||
op->setOperand(idx, *none);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/*!
|
||||
* Create a Shape Inference pass.
|
||||
*/
|
||||
std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() {
|
||||
return std::make_unique<AttributePromotionPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<AttributePromotionPass> pass(
|
||||
"attribute-promotion", "Shape inference for frontend dialects.");
|
|
@ -0,0 +1,24 @@
|
|||
//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the declarations of the shape inference interfaces defined
|
||||
// in ShapeInferenceInterface.td.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Include the auto-generated declarations.
|
||||
#include "src/dialect/onnx/const_operands_interface.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
|
@ -0,0 +1,32 @@
|
|||
//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines the operations of the Shape Inference Op Interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
#else
|
||||
#define IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def IdentifyConstOperandsOpInterface : OpInterface<"IdentifyConstOperandsOpInterface"> {
|
||||
let description = [{
|
||||
Interface to access a registered method to infer the return types for an
|
||||
operation that can be used during type inference.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||
"std::map<std::string, size_t>", "identifyPromotableConstOperands">
|
||||
];
|
||||
}
|
||||
|
||||
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
@ -22,6 +22,11 @@ include "mlir/IR/OpBase.td"
|
|||
include "pass/shape_inference_interface.td"
|
||||
#endif // SHAPE_INFERENCE_INTERFACE
|
||||
|
||||
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
#else
|
||||
include "dialect/onnx/const_operands_interface.td"
|
||||
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
|
||||
def ONNX_Dialect : Dialect {
|
||||
let name = "onnx";
|
||||
let cppNamespace = "";
|
||||
|
|
|
@ -10,6 +10,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
@ -17,6 +20,7 @@
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/pass/shape_inference_interface.hpp"
|
||||
#include "src/dialect/onnx/const_operands_interface.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ def ONNXAbsOp:ONNX_Op<"Abs",
|
|||
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||
build(builder, state, outputTypes, operands, attributes);
|
||||
}]>
|
||||
];
|
||||
];
|
||||
}
|
||||
|
||||
def ONNXAcosOp:ONNX_Op<"Acos",
|
||||
|
@ -2501,7 +2501,9 @@ def ONNXReluOp:ONNX_Op<"Relu",
|
|||
}
|
||||
|
||||
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
[NoSideEffect,
|
||||
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
|
||||
OpInterface<"IdentifyConstOperandsOpInterface">]> {
|
||||
let summary = "ONNX Reshape operation";
|
||||
let description = [{
|
||||
"Reshape the input tensor similar to numpy.reshape."
|
||||
|
@ -2512,8 +2514,11 @@ def ONNXReshapeOp:ONNX_Op<"Reshape",
|
|||
"from the input tensor)."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape);
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, size_t> identifyPromotableConstOperands() { return {{"shape", 1}}; }
|
||||
}];
|
||||
}
|
||||
|
||||
def ONNXResizeOp:ONNX_Op<"Resize",
|
||||
|
|
|
@ -126,6 +126,7 @@ int main(int argc, char *argv[]) {
|
|||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createAttributePromotionPass());
|
||||
|
||||
if (emissionTarget >= EmitMLIR) {
|
||||
pm.addPass(mlir::createLowerToKrnlPass());
|
||||
|
|
|
@ -20,6 +20,8 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
|
|||
|
||||
std::unique_ptr<Pass> createShapeInferencePass();
|
||||
|
||||
std::unique_ptr<Pass> createAttributePromotionPass();
|
||||
|
||||
/// Add pass for lowering to Krnl IR.
|
||||
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||
|
||||
|
|
|
@ -12,9 +12,9 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "shape_inference_interface.hpp"
|
||||
#include "src/dialect/onnx/onnx_ops.hpp"
|
||||
|
||||
#include "passes.hpp"
|
||||
|
||||
|
|
Loading…
Reference in New Issue