Support attribute promotion. (#34)
* Support attribute promotion. * Simplify op interface name. * 1. Add more comments to Attribute Promotion Pass. 2. Move Promotable Const Operand Interface to src/interface, and link against it. * Complete NFC change onnx -> onnx-mlir. * Move attribute_promotion pass to src/transform. * Nit: reword comment. * Support Attribute Promotion in gen_doc.py. * Add test. * Update ONNX doc. * Add negative test. * Rename onnxop.inc -> onnx_ops.td.inc. * Include onnx_ops.td.inc. * Nit: better comments. * Prettify CMake. * Remove original attribute_promotion code, improve comments. * Append '_op_interface' to op interface decl/defs. * Namespace cmake targets using onnx_mlir_ prefix. * Use updated header name. * Use new body file name. * Fix dependency. * Use new CMake target name. * Make attribute promotion self-contained by removing redundant constant operaions inside the pass execution. * Remove canonicalization pass. * Increase comments. * Use stricter checks. * Add one more test case. * Remove %arg1 as it's never used.
This commit is contained in:
parent
2814ea3898
commit
549af8f0b2
|
@ -3749,7 +3749,7 @@ ONNX Reshape operation
|
|||
#### Operands:
|
||||
|
||||
1. `data`: memref of any type values or tensor of any type values
|
||||
1. `shape`: memref of any type values or tensor of any type values
|
||||
1. `shape`: memref of any type values or tensor of any type values or none type
|
||||
|
||||
#### Attributes:
|
||||
|
||||
|
|
|
@ -47,13 +47,20 @@ OpsWithShapeInference = [
|
|||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||
'Sign', 'Constant', 'ONNXAveragePoolOp', 'Abs'
|
||||
'Sign', 'Constant', 'AveragePool', 'Abs'
|
||||
]
|
||||
|
||||
# Operations supporting canonicalization.
|
||||
OpsWithCanonicalizer = [
|
||||
'Add', 'Identity', 'Gemm'
|
||||
]
|
||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm']
|
||||
|
||||
# Operations who have operands that, if produced by constant operations, should
|
||||
# be promoted to become an attribute (via attribute promotion).
|
||||
#
|
||||
# For each operation, a key/value pair is used to specify how attribute promotion
|
||||
# should proceed. The key is the operation's name and the value is a list of
|
||||
# tuples, whose first item is the attribute/operand name, and the second item is
|
||||
# the index at which such operand occurs in the list of the operation's inputs.
|
||||
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}
|
||||
|
||||
# Add an Op in this list if the Op needs result type deduction which is required
|
||||
# when writing declarative rewriting rules. Deduced type is always
|
||||
|
@ -224,7 +231,7 @@ def get_operands_or_results(schema, is_input):
|
|||
return "AnyTypeOf<[{}]>".format(", ".join(types))
|
||||
|
||||
name_to_types = OrderedDict()
|
||||
for value in value_list:
|
||||
for i, value in enumerate(value_list):
|
||||
elem_types = get_allowed_elem_types(schema, value)
|
||||
|
||||
if elem_types is None:
|
||||
|
@ -233,6 +240,13 @@ def get_operands_or_results(schema, is_input):
|
|||
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
|
||||
types = list(map(lambda x: x.format(elem_types), types))
|
||||
|
||||
# If operand is promotable to an attribute, then it must be
|
||||
# nullable in case it migrates to be an attribute.
|
||||
if schema.name in OpsWithPromotableConstOperands:
|
||||
idxs = dict(OpsWithPromotableConstOperands[schema.name]).values()
|
||||
if i in idxs:
|
||||
types.append("NoneType")
|
||||
|
||||
if OpSchema.FormalParameterOption.Optional == value.option:
|
||||
types.append("NoneType")
|
||||
elif OpSchema.FormalParameterOption.Variadic == value.option:
|
||||
|
@ -313,6 +327,25 @@ def get_attrs(schema):
|
|||
return name_to_type
|
||||
|
||||
|
||||
def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
|
||||
cpp_name_to_idx_literal = "{" + ", ".join([
|
||||
"{{\"{}\", {}}}".format(*name_to_idx)
|
||||
for name_to_idx in const_operands_name_to_idx
|
||||
]) + "}"
|
||||
|
||||
s += indent + "let extraClassDeclaration = [{\n"
|
||||
indent = inc_indent(indent)
|
||||
s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n"
|
||||
indent = inc_indent(indent)
|
||||
s += indent + "return {};\n".format(cpp_name_to_idx_literal)
|
||||
indent = dec_indent(indent)
|
||||
s += indent + "}\n"
|
||||
indent = dec_indent(indent)
|
||||
s += indent + "}];\n"
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def gen_op_def(schema):
|
||||
indent = inc_indent()
|
||||
s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name)
|
||||
|
@ -321,6 +354,8 @@ def gen_op_def(schema):
|
|||
traits = ["NoSideEffect"]
|
||||
if schema.name in OpsWithShapeInference:
|
||||
traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
|
||||
if schema.name in OpsWithPromotableConstOperands.keys():
|
||||
traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">")
|
||||
s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))
|
||||
|
||||
# Generate decl for canonicalizer.
|
||||
|
@ -400,6 +435,9 @@ def gen_op_def(schema):
|
|||
|
||||
s += '\n' + indent + '];\n'
|
||||
|
||||
if schema.name in OpsWithPromotableConstOperands:
|
||||
s = get_promotable_const_operands_func(
|
||||
s, indent, OpsWithPromotableConstOperands[schema.name])
|
||||
s += '}\n\n'
|
||||
return s
|
||||
|
||||
|
@ -506,7 +544,7 @@ if __name__ == '__main__':
|
|||
curr_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
class Args(object):
|
||||
op_def_file = os.path.join(curr_dir, 'onnxop.inc')
|
||||
op_def_file = os.path.join(curr_dir, 'onnx_ops.td.inc')
|
||||
op_importer_file = os.path.join(curr_dir, 'op_build_table.inc')
|
||||
|
||||
main(Args)
|
||||
|
|
|
@ -8,7 +8,6 @@ add_library(compiler
|
|||
dialect/krnl/krnl_helper.cpp
|
||||
dialect/krnl/krnl_helper.hpp
|
||||
pass/shape_inference_interface.hpp
|
||||
dialect/onnx/onnxop.inc
|
||||
pass/onnx_combine.cpp
|
||||
pass/onnx_rewrite.cpp
|
||||
pass/onnx_decompose.cpp
|
||||
|
@ -47,12 +46,22 @@ onnx_mlir_tablegen(onnx_rewrite.inc -gen-rewriters)
|
|||
add_public_tablegen_target(gen_onnx_rewrite)
|
||||
add_dependencies(compiler gen_onnx_rewrite)
|
||||
|
||||
add_subdirectory(interface)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||
onnx_mlir_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
onnx_mlir_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
|
||||
add_public_tablegen_target(gen_onnx)
|
||||
|
||||
add_dependencies(gen_onnx gen_shape_inference)
|
||||
add_dependencies(compiler gen_onnx)
|
||||
|
||||
# TODO: onnx_mlir_gen_promotable_const_operands_op_interface is really a
|
||||
# dependency of the onnx dialect library, which is currently part of `compiler`.
|
||||
add_dependencies(compiler onnx_mlir_gen_promotable_const_operands_op_interface)
|
||||
|
||||
|
||||
add_onnx_mlir_dialect_doc(onnx dialect/onnx/onnx.td)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
||||
|
@ -66,14 +75,14 @@ target_include_directories(onnx_mlir_onnx_decompose
|
|||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
target_link_libraries(onnx_mlir_onnx_decompose ${MLIRLibs})
|
||||
add_dependencies(onnx_mlir_onnx_decompose gen_krnl_ops)
|
||||
add_dependencies(onnx_mlir_onnx_decompose gen_onnx)
|
||||
|
||||
add_library(onnx_mlir_shape_inference pass/shape_inference_pass.cpp)
|
||||
target_include_directories(onnx_mlir_shape_inference
|
||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
target_link_libraries(onnx_mlir_shape_inference ${MLIRLibs})
|
||||
add_dependencies(onnx_mlir_shape_inference gen_krnl_ops)
|
||||
add_dependencies(onnx_mlir_shape_inference gen_onnx)
|
||||
|
||||
add_library(onnx_mlir_lower_frontend
|
||||
conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
|
||||
|
@ -106,8 +115,17 @@ add_subdirectory(runtime)
|
|||
|
||||
add_executable(onnx-mlir main.cpp)
|
||||
|
||||
target_link_libraries(onnx-mlir builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_onnx_decompose onnx_mlir_shape_inference onnx_mlir_lower_frontend)
|
||||
whole_archive_link_mlir(onnx-mlir ${MLIRWholeArchiveLibs})
|
||||
target_link_libraries(onnx-mlir
|
||||
builder
|
||||
${MLIRLibs}
|
||||
onnx_mlir_transform
|
||||
onnx_mlir_onnx_decompose
|
||||
onnx_mlir_shape_inference
|
||||
onnx_mlir_lower_frontend
|
||||
onnx_mlir_attribute_promotion)
|
||||
whole_archive_link_mlir(onnx-mlir
|
||||
${MLIRWholeArchiveLibs})
|
||||
|
||||
find_package(ZLIB REQUIRED)
|
||||
target_link_libraries(onnx-mlir ${ZLIB_LIBRARIES})
|
||||
|
||||
|
|
|
@ -22,6 +22,11 @@ include "mlir/IR/OpBase.td"
|
|||
include "pass/shape_inference_interface.td"
|
||||
#endif // SHAPE_INFERENCE_INTERFACE
|
||||
|
||||
#ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
#else
|
||||
include "interface/promotable_const_operands_op_interface.td"
|
||||
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
|
||||
def ONNX_Dialect : Dialect {
|
||||
let name = "onnx";
|
||||
let cppNamespace = "";
|
||||
|
@ -48,7 +53,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
// CC=gcc CXX=g++ pip install -e .
|
||||
//run the script
|
||||
// python onnx/defs/gen_doc.py
|
||||
//result is in docs/onnxop.inc
|
||||
//result is in docs/onnx_ops.td.inc
|
||||
//current limitations:
|
||||
// 1. Attributes are not processed
|
||||
// 2. output type inference not implemented except Add
|
||||
|
@ -56,7 +61,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
// 4. type of string, complex64 and complex128 for input/output are ignored
|
||||
// 5. unsigned int are treated as signed one
|
||||
|
||||
include "dialect/onnx/onnxop.inc"
|
||||
include "dialect/onnx/onnx_ops.td.inc"
|
||||
|
||||
// Indicate entry point functions of ONNX graph.
|
||||
def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
||||
|
|
|
@ -10,6 +10,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
@ -17,6 +20,7 @@
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/pass/shape_inference_interface.hpp"
|
||||
#include "src/interface/promotable_const_operands_op_interface.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
//********************************************************
|
||||
// This file is generated on UTC-02/24/2020, 06:44:13.
|
||||
// This file is generated on UTC-03/18/2020, 03:36:58.
|
||||
// Do not modify this file directly.
|
||||
// This file is automatically generated via script.
|
||||
// Details can be found in doc/readonnxdefs.md .
|
||||
|
@ -2501,7 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
|
|||
}
|
||||
|
||||
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
|
||||
let summary = "ONNX Reshape operation";
|
||||
let description = [{
|
||||
"Reshape the input tensor similar to numpy.reshape."
|
||||
|
@ -2512,8 +2512,13 @@ def ONNXReshapeOp:ONNX_Op<"Reshape",
|
|||
"from the input tensor)."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape);
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, size_t> promotableConstOperands() {
|
||||
return {{"shape", 1}};
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def ONNXResizeOp:ONNX_Op<"Resize",
|
|
@ -0,0 +1,14 @@
|
|||
set(LLVM_TARGET_DEFINITIONS promotable_const_operands_op_interface.td)
|
||||
onnx_mlir_tablegen(promotable_const_operands_op_interface.hpp.inc -gen-op-interface-decls)
|
||||
onnx_mlir_tablegen(promotable_const_operands_op_interface.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(onnx_mlir_gen_promotable_const_operands_op_interface)
|
||||
|
||||
add_library(onnx_mlir_promotable_const_operands_op_interface
|
||||
promotable_const_operands_op_interface.hpp
|
||||
promotable_const_operands_op_interface.cpp)
|
||||
target_include_directories(onnx_mlir_promotable_const_operands_op_interface
|
||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
|
||||
add_dependencies(onnx_mlir_promotable_const_operands_op_interface
|
||||
onnx_mlir_gen_promotable_const_operands_op_interface)
|
|
@ -0,0 +1,23 @@
|
|||
//===------------ promotable_const_operands_op_interface.cpp --------------===//
|
||||
//===-------- Promotable Const Operands Op Interface Definition -----------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the definition of the promotable const operands op
|
||||
// interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
#include "src/interface/promotable_const_operands_op_interface.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Promotable Const Operands Op Interface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/interface/promotable_const_operands_op_interface.cpp.inc"
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
//===------------ promotable_const_operands_op_interface.cpp --------------===//
|
||||
//===-------- Promotable Const Operands Op Interface Definition -----------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the declaration of the promotable const operands op
|
||||
// interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Include the auto-generated declarations.
|
||||
#include "src/interface/promotable_const_operands_op_interface.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
|
@ -0,0 +1,34 @@
|
|||
//===------------- promotable_const_operands_op_interface.td --------------===//
|
||||
//===---- Promotable Const Operands Op Interface TableGen Definition ------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the TableGen definition of the promotable const operands
|
||||
// op interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
#else
|
||||
#define PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def PromotableConstOperandsOpInterface : OpInterface<"PromotableConstOperandsOpInterface"> {
|
||||
let description = [{
|
||||
Interface to access a registered method to infer the return types for an
|
||||
operation that can be used during type inference.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||
"std::map<std::string, size_t>", "promotableConstOperands">
|
||||
];
|
||||
}
|
||||
|
||||
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
@ -126,6 +126,7 @@ int main(int argc, char *argv[]) {
|
|||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createAttributePromotionPass());
|
||||
|
||||
if (emissionTarget >= EmitMLIR) {
|
||||
pm.addPass(mlir::createLowerToKrnlPass());
|
||||
|
|
|
@ -20,6 +20,9 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
|
|||
|
||||
std::unique_ptr<Pass> createShapeInferencePass();
|
||||
|
||||
/// Pass for promoting constant operands to attributes.
|
||||
std::unique_ptr<Pass> createAttributePromotionPass();
|
||||
|
||||
/// Add pass for lowering to Krnl IR.
|
||||
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||
|
||||
|
|
|
@ -12,9 +12,9 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "shape_inference_interface.hpp"
|
||||
#include "src/dialect/onnx/onnx_ops.hpp"
|
||||
|
||||
#include "passes.hpp"
|
||||
|
||||
|
|
|
@ -4,6 +4,19 @@ add_dependencies(onnx-mlir-opt gen_krnl_ops)
|
|||
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT})
|
||||
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT})
|
||||
|
||||
target_link_libraries(onnx-mlir-opt builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_shape_inference onnx_mlir_lower_frontend curses)
|
||||
whole_archive_link_mlir(onnx-mlir-opt ${MLIRWholeArchiveLibs})
|
||||
whole_archive_link_onnx_mlir(onnx-mlir-opt compiler onnx_mlir_transform onnx_mlir_lower_frontend onnx_mlir_shape_inference)
|
||||
target_link_libraries(onnx-mlir-opt
|
||||
builder
|
||||
${MLIRLibs}
|
||||
onnx_mlir_transform
|
||||
onnx_mlir_shape_inference
|
||||
onnx_mlir_lower_frontend
|
||||
onnx_mlir_promotable_const_operands_op_interface
|
||||
curses)
|
||||
whole_archive_link_mlir(onnx-mlir-opt
|
||||
${MLIRWholeArchiveLibs})
|
||||
whole_archive_link_onnx_mlir(onnx-mlir-opt
|
||||
compiler
|
||||
onnx_mlir_transform
|
||||
onnx_mlir_lower_frontend
|
||||
onnx_mlir_shape_inference
|
||||
onnx_mlir_attribute_promotion)
|
|
@ -7,3 +7,5 @@ target_include_directories(onnx_mlir_transform
|
|||
${ONNX_MLIR_SRC_ROOT})
|
||||
target_link_libraries(onnx_mlir_transform ${MLIRLibs})
|
||||
add_dependencies(onnx_mlir_transform gen_krnl_ops)
|
||||
|
||||
add_subdirectory(onnx)
|
|
@ -0,0 +1,7 @@
|
|||
add_library(onnx_mlir_attribute_promotion
|
||||
attribute_promotion.cpp)
|
||||
target_include_directories(onnx_mlir_attribute_promotion
|
||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNF_MLIR_SRC_ROOT})
|
||||
target_link_libraries(onnx_mlir_attribute_promotion
|
||||
onnx_mlir_promotable_const_operands_op_interface)
|
|
@ -0,0 +1,92 @@
|
|||
//===----- attribute_promotion.cpp - Attribute Promotion
|
||||
//-------------------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a function level pass to move an operand to become
|
||||
// an attribute if desirable and legal.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/interface/promotable_const_operands_op_interface.hpp"
|
||||
#include "src/pass/passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/*!
|
||||
* Helper function to create a NoneTyped constant value if `none` is empty.
|
||||
*/
|
||||
void getOrCreateNoneValue(llvm::Optional<mlir::Value> &none, FuncOp f) {
|
||||
if (none.hasValue())
|
||||
return;
|
||||
|
||||
OpBuilder builder(f.getContext());
|
||||
builder.setInsertionPointToStart(&f.front());
|
||||
none = builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr());
|
||||
}
|
||||
|
||||
/*!
|
||||
* FunctionPass that performs attribute promotion by iterating over a list of
|
||||
* candidate operations and moves constant operands to attributes whenever
|
||||
* desirable (as instructed by the PromotableConstOperandsOpInterface).
|
||||
*/
|
||||
class AttributePromotionPass
|
||||
: public mlir::FunctionPass<AttributePromotionPass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto f = getFunction();
|
||||
|
||||
// A function-scope shared none value used to indicate an missing operand.
|
||||
llvm::Optional<mlir::Value> none;
|
||||
|
||||
// Iterate on the operations that may need attribute promotion.
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (PromotableConstOperandsOpInterface opWithConstOperand =
|
||||
dyn_cast<PromotableConstOperandsOpInterface>(op)) {
|
||||
auto promotableOperands = opWithConstOperand.promotableConstOperands();
|
||||
for (const auto &operandNameToIdx : promotableOperands) {
|
||||
auto name = operandNameToIdx.first;
|
||||
auto i = operandNameToIdx.second;
|
||||
|
||||
// If the i-th operand is defined by an constant operation, then
|
||||
// move it to an attribute, and use None to indicate the absence
|
||||
// of the original operand value.
|
||||
auto operandToPromote = op->getOperand(i);
|
||||
if (auto constantOp = dyn_cast_or_null<ConstantOp>(
|
||||
operandToPromote.getDefiningOp())) {
|
||||
op->setAttr(name, constantOp.value());
|
||||
getOrCreateNoneValue(none, f);
|
||||
op->setOperand(i, *none);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Dispatch canonicalization pattern rewriters to eliminate redundant
|
||||
// constant operaions.
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
ConstantOp::getCanonicalizationPatterns(patterns, context);
|
||||
applyPatternsGreedily(f, patterns);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/*!
|
||||
* Create a Attribute Promotion pass.
|
||||
*/
|
||||
std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() {
|
||||
return std::make_unique<AttributePromotionPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<AttributePromotionPass> pass(
|
||||
"attribute-promotion", "Promote constant operands to attributes.");
|
|
@ -0,0 +1,32 @@
|
|||
// RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s
|
||||
|
||||
func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||
%shape = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: test_should_promote_to_attribute
|
||||
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: test_should_not_promote_to_attribute
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, %{{.*}}) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%shape = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||
%1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32>
|
||||
"std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
|
||||
// CHECK-LABEL: test_promote_to_attribute_without_removing_const_op
|
||||
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||
// CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
Loading…
Reference in New Issue