Support attribute promotion. (#34)

* Support attribute promotion.

* Simplify op interface name.

* 1. Add more comments to Attribute Promotion Pass.
2. Move Promotable Const Operand Interface to src/interface, and link against it.

* Complete NFC change onnx -> onnx-mlir.

* Move attribute_promotion pass to src/transform.

* Nit: reword comment.

* Support Attribute Promotion in gen_doc.py.

* Add test.

* Update ONNX doc.

* Add negative test.

* Rename onnxop.inc -> onnx_ops.td.inc.

* Include onnx_ops.td.inc.

* Nit: better comments.

* Prettify CMake.

* Remove original attribute_promotion code, improve comments.

* Append '_op_interface' to op interface decl/defs.

* Namespace cmake targets using onnx_mlir_ prefix.

* Use updated header name.

* Use new body file name.

* Fix dependency.

* Use new CMake target name.

* Make attribute promotion self-contained by removing redundant constant operations inside the pass execution.

* Remove canonicalization pass.

* Expand comments.

* Use stricter checks.

* Add one more test case.

* Remove %arg1 as it's never used.
Tian Jin, 2020-03-19 15:03:37 +08:00, committed by GitHub
parent 2814ea3898
commit 549af8f0b2
19 changed files with 337 additions and 21 deletions


@@ -3749,7 +3749,7 @@ ONNX Reshape operation
 #### Operands:
 1. `data`: memref of any type values or tensor of any type values
-1. `shape`: memref of any type values or tensor of any type values
+1. `shape`: memref of any type values or tensor of any type values or none type
 #### Attributes:


@@ -47,13 +47,20 @@ OpsWithShapeInference = [
     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
-    'Sign', 'Constant', 'ONNXAveragePoolOp', 'Abs'
+    'Sign', 'Constant', 'AveragePool', 'Abs'
 ]

 # Operations supporting canonicalization.
-OpsWithCanonicalizer = [
-    'Add', 'Identity', 'Gemm'
-]
+OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm']
+
+# Operations whose operands, if produced by constant operations, should
+# be promoted to become attributes (via attribute promotion).
+#
+# For each operation, a key/value pair is used to specify how attribute promotion
+# should proceed. The key is the operation's name and the value is a list of
+# tuples, whose first item is the attribute/operand name, and the second item is
+# the index at which such operand occurs in the list of the operation's inputs.
+OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}

 # Add an Op in this list if the Op needs result type deduction which is required
 # when writing declarative rewriting rules. Deduced type is always
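
(Illustrative aside, not part of the commit: a minimal sketch of how the registry above is meant to be consulted. The helper name promotable_indices is hypothetical; the real lookup happens inline in get_operands_or_results, shown in the next hunk.)

# Hypothetical helper, for illustration only: the operand indices listed in
# the registry are the ones that must also admit NoneType.
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}

def promotable_indices(op_name):
    # dict() turns the (name, index) tuples into a name -> index mapping.
    return set(dict(OpsWithPromotableConstOperands.get(op_name, [])).values())

assert promotable_indices("Reshape") == {1}  # Reshape's 'shape' is operand 1.
assert promotable_indices("Relu") == set()   # Unregistered ops yield nothing.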
@@ -224,7 +231,7 @@ def get_operands_or_results(schema, is_input):
         return "AnyTypeOf<[{}]>".format(", ".join(types))

     name_to_types = OrderedDict()
-    for value in value_list:
+    for i, value in enumerate(value_list):
         elem_types = get_allowed_elem_types(schema, value)

         if elem_types is None:
@@ -233,6 +240,13 @@ def get_operands_or_results(schema, is_input):
             types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
             types = list(map(lambda x: x.format(elem_types), types))

+            # If operand is promotable to an attribute, then it must be
+            # nullable in case it migrates to be an attribute.
+            if schema.name in OpsWithPromotableConstOperands:
+                idxs = dict(OpsWithPromotableConstOperands[schema.name]).values()
+                if i in idxs:
+                    types.append("NoneType")
+
         if OpSchema.FormalParameterOption.Optional == value.option:
             types.append("NoneType")
         elif OpSchema.FormalParameterOption.Variadic == value.option:
@@ -313,6 +327,25 @@ def get_attrs(schema):
     return name_to_type

+def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
+    cpp_name_to_idx_literal = "{" + ", ".join([
+        "{{\"{}\", {}}}".format(*name_to_idx)
+        for name_to_idx in const_operands_name_to_idx
+    ]) + "}"
+
+    s += indent + "let extraClassDeclaration = [{\n"
+    indent = inc_indent(indent)
+    s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n"
+    indent = inc_indent(indent)
+    s += indent + "return {};\n".format(cpp_name_to_idx_literal)
+    indent = dec_indent(indent)
+    s += indent + "}\n"
+    indent = dec_indent(indent)
+    s += indent + "}];\n"
+
+    return s
+
 def gen_op_def(schema):
     indent = inc_indent()
     s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name)
@@ -321,6 +354,8 @@ def gen_op_def(schema):
     traits = ["NoSideEffect"]
     if schema.name in OpsWithShapeInference:
         traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
+    if schema.name in OpsWithPromotableConstOperands.keys():
+        traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">")
     s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))

     # Generate decl for canonicalizer.
@@ -400,6 +435,9 @@ def gen_op_def(schema):
         s += '\n' + indent + '];\n'

+    if schema.name in OpsWithPromotableConstOperands:
+        s = get_promotable_const_operands_func(
+            s, indent, OpsWithPromotableConstOperands[schema.name])
+
     s += '}\n\n'
     return s
@@ -506,7 +544,7 @@ if __name__ == '__main__':
     curr_dir = os.path.dirname(os.path.realpath(__file__))

     class Args(object):
-        op_def_file = os.path.join(curr_dir, 'onnxop.inc')
+        op_def_file = os.path.join(curr_dir, 'onnx_ops.td.inc')
         op_importer_file = os.path.join(curr_dir, 'op_build_table.inc')

 main(Args)
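
(Illustrative aside, not part of the commit: a standalone sketch of the brace-escaped literal that get_promotable_const_operands_func builds. The doubled braces in the format string emit literal `{` and `}`, so the result is a valid C++ brace initializer.)

entries = [("shape", 1)]  # the OpsWithPromotableConstOperands entry for Reshape
literal = "{" + ", ".join(
    "{{\"{}\", {}}}".format(*name_to_idx) for name_to_idx in entries) + "}"
assert literal == '{{"shape", 1}}'
# Emitted into the op definition as: return {{"shape", 1}};
# which is exactly the body seen in ONNXReshapeOp's extraClassDeclaration below.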


@@ -8,7 +8,6 @@ add_library(compiler
         dialect/krnl/krnl_helper.cpp
         dialect/krnl/krnl_helper.hpp
         pass/shape_inference_interface.hpp
-        dialect/onnx/onnxop.inc
         pass/onnx_combine.cpp
         pass/onnx_rewrite.cpp
         pass/onnx_decompose.cpp
@@ -47,12 +46,22 @@ onnx_mlir_tablegen(onnx_rewrite.inc -gen-rewriters)
 add_public_tablegen_target(gen_onnx_rewrite)
 add_dependencies(compiler gen_onnx_rewrite)

+add_subdirectory(interface)
+
 set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
 onnx_mlir_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
 onnx_mlir_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
 set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
 add_public_tablegen_target(gen_onnx)
+add_dependencies(gen_onnx gen_shape_inference)
 add_dependencies(compiler gen_onnx)
+
+# TODO: onnx_mlir_gen_promotable_const_operands_op_interface is really a
+# dependency of the onnx dialect library, which is currently part of `compiler`.
+add_dependencies(compiler onnx_mlir_gen_promotable_const_operands_op_interface)
+
 add_onnx_mlir_dialect_doc(onnx dialect/onnx/onnx.td)

 set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
@@ -66,14 +75,14 @@ target_include_directories(onnx_mlir_onnx_decompose
         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
         ${ONNX_MLIR_SRC_ROOT})
 target_link_libraries(onnx_mlir_onnx_decompose ${MLIRLibs})
-add_dependencies(onnx_mlir_onnx_decompose gen_krnl_ops)
+add_dependencies(onnx_mlir_onnx_decompose gen_onnx)

 add_library(onnx_mlir_shape_inference pass/shape_inference_pass.cpp)
 target_include_directories(onnx_mlir_shape_inference
         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
         ${ONNX_MLIR_SRC_ROOT})
 target_link_libraries(onnx_mlir_shape_inference ${MLIRLibs})
-add_dependencies(onnx_mlir_shape_inference gen_krnl_ops)
+add_dependencies(onnx_mlir_shape_inference gen_onnx)

 add_library(onnx_mlir_lower_frontend
         conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
@@ -106,8 +115,17 @@ add_subdirectory(runtime)

 add_executable(onnx-mlir main.cpp)
-target_link_libraries(onnx-mlir builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_onnx_decompose onnx_mlir_shape_inference onnx_mlir_lower_frontend)
-whole_archive_link_mlir(onnx-mlir ${MLIRWholeArchiveLibs})
+target_link_libraries(onnx-mlir
+        builder
+        ${MLIRLibs}
+        onnx_mlir_transform
+        onnx_mlir_onnx_decompose
+        onnx_mlir_shape_inference
+        onnx_mlir_lower_frontend
+        onnx_mlir_attribute_promotion)
+whole_archive_link_mlir(onnx-mlir
+        ${MLIRWholeArchiveLibs})

 find_package(ZLIB REQUIRED)
 target_link_libraries(onnx-mlir ${ZLIB_LIBRARIES})


@@ -22,6 +22,11 @@ include "mlir/IR/OpBase.td"
 include "pass/shape_inference_interface.td"
 #endif // SHAPE_INFERENCE_INTERFACE

+#ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
+#else
+include "interface/promotable_const_operands_op_interface.td"
+#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
+
 def ONNX_Dialect : Dialect {
   let name = "onnx";
   let cppNamespace = "";
@@ -48,7 +53,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
 //    CC=gcc CXX=g++ pip install -e .
 //run the script
 //    python onnx/defs/gen_doc.py
-//result is in docs/onnxop.inc
+//result is in docs/onnx_ops.td.inc
 //current limitations:
 // 1. Attributes are not processed
 // 2. output type inference not implemented except Add
@@ -56,7 +61,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
 // 4. type of string, complex64 and complex128 for input/output are ignored
 // 5. unsigned int are treated as signed one

-include "dialect/onnx/onnxop.inc"
+include "dialect/onnx/onnx_ops.td.inc"

 // Indicate entry point functions of ONNX graph.
 def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {


@@ -10,6 +10,9 @@
 #pragma once

+#include <map>
+#include <string>
+
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
@@ -17,6 +20,7 @@
 #include "mlir/IR/StandardTypes.h"

 #include "src/pass/shape_inference_interface.hpp"
+#include "src/interface/promotable_const_operands_op_interface.hpp"

 namespace mlir {


@@ -1,5 +1,5 @@
 //********************************************************
-// This file is generated on UTC-02/24/2020, 06:44:13.
+// This file is generated on UTC-03/18/2020, 03:36:58.
 // Do not modify this file directly.
 // This file is automatically generated via script.
 // Details can be found in doc/readonnxdefs.md .
@@ -2501,7 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
 }

 def ONNXReshapeOp:ONNX_Op<"Reshape",
-  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
   let summary = "ONNX Reshape operation";
   let description = [{
     "Reshape the input tensor similar to numpy.reshape."
@@ -2512,8 +2512,13 @@ def ONNXReshapeOp:ONNX_Op<"Reshape",
     "from the input tensor)."
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
-    AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape);
+    AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
+  let extraClassDeclaration = [{
+    std::map<std::string, size_t> promotableConstOperands() {
+      return {{"shape", 1}};
+    }
+  }];
 }

 def ONNXResizeOp:ONNX_Op<"Resize",


@@ -0,0 +1,14 @@
+set(LLVM_TARGET_DEFINITIONS promotable_const_operands_op_interface.td)
+onnx_mlir_tablegen(promotable_const_operands_op_interface.hpp.inc -gen-op-interface-decls)
+onnx_mlir_tablegen(promotable_const_operands_op_interface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(onnx_mlir_gen_promotable_const_operands_op_interface)
+
+add_library(onnx_mlir_promotable_const_operands_op_interface
+        promotable_const_operands_op_interface.hpp
+        promotable_const_operands_op_interface.cpp)
+target_include_directories(onnx_mlir_promotable_const_operands_op_interface
+        PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
+        ${ONNX_MLIR_SRC_ROOT})
+add_dependencies(onnx_mlir_promotable_const_operands_op_interface
+        onnx_mlir_gen_promotable_const_operands_op_interface)


@@ -0,0 +1,23 @@
+//===------------ promotable_const_operands_op_interface.cpp --------------===//
+//===-------- Promotable Const Operands Op Interface Definition -----------===//
+//
+// Copyright 2020 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains the definition of the promotable const operands op
+// interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/interface/promotable_const_operands_op_interface.hpp"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Promotable Const Operands Op Interface
+//===----------------------------------------------------------------------===//
+
+#include "src/interface/promotable_const_operands_op_interface.cpp.inc"


@@ -0,0 +1,25 @@
+//===------------ promotable_const_operands_op_interface.hpp --------------===//
+//===------- Promotable Const Operands Op Interface Declaration -----------===//
+//
+// Copyright 2020 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains the declaration of the promotable const operands op
+// interface.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <map>
+#include <string>
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+
+/// Include the auto-generated declarations.
+#include "src/interface/promotable_const_operands_op_interface.hpp.inc"
+
+} // end namespace mlir


@@ -0,0 +1,34 @@
+//===------------- promotable_const_operands_op_interface.td --------------===//
+//===---- Promotable Const Operands Op Interface TableGen Definition ------===//
+//
+// Copyright 2020 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains the TableGen definition of the promotable const operands
+// op interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
+#else
+#define PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def PromotableConstOperandsOpInterface : OpInterface<"PromotableConstOperandsOpInterface"> {
+  let description = [{
+    Interface to query which operands of an operation should be promoted to
+    attributes when they are produced by constant operations.
+  }];
+
+  let methods = [
+    InterfaceMethod<"Return a map from each promotable operand's name to its operand index.",
+        "std::map<std::string, size_t>", "promotableConstOperands">
+  ];
+}
+
+#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE


@@ -126,6 +126,7 @@ int main(int argc, char *argv[]) {
   pm.addPass(mlir::createShapeInferencePass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createShapeInferencePass());
+  pm.addPass(mlir::createAttributePromotionPass());

   if (emissionTarget >= EmitMLIR) {
     pm.addPass(mlir::createLowerToKrnlPass());


@@ -20,6 +20,9 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();

 std::unique_ptr<Pass> createShapeInferencePass();

+/// Pass for promoting constant operands to attributes.
+std::unique_ptr<Pass> createAttributePromotionPass();
+
 /// Add pass for lowering to Krnl IR.
 std::unique_ptr<Pass> createLowerToKrnlPass();


@@ -12,9 +12,9 @@
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/StandardTypes.h"

 #include "shape_inference_interface.hpp"
+#include "src/dialect/onnx/onnx_ops.hpp"

 #include "passes.hpp"


@@ -4,6 +4,19 @@ add_dependencies(onnx-mlir-opt gen_krnl_ops)
 target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT})
 target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT})

-target_link_libraries(onnx-mlir-opt builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_shape_inference onnx_mlir_lower_frontend curses)
-whole_archive_link_mlir(onnx-mlir-opt ${MLIRWholeArchiveLibs})
-whole_archive_link_onnx_mlir(onnx-mlir-opt compiler onnx_mlir_transform onnx_mlir_lower_frontend onnx_mlir_shape_inference)
+target_link_libraries(onnx-mlir-opt
+        builder
+        ${MLIRLibs}
+        onnx_mlir_transform
+        onnx_mlir_shape_inference
+        onnx_mlir_lower_frontend
+        onnx_mlir_promotable_const_operands_op_interface
+        curses)
+whole_archive_link_mlir(onnx-mlir-opt
+        ${MLIRWholeArchiveLibs})
+whole_archive_link_onnx_mlir(onnx-mlir-opt
+        compiler
+        onnx_mlir_transform
+        onnx_mlir_lower_frontend
+        onnx_mlir_shape_inference
+        onnx_mlir_attribute_promotion)


@@ -7,3 +7,5 @@ target_include_directories(onnx_mlir_transform
         ${ONNX_MLIR_SRC_ROOT})
 target_link_libraries(onnx_mlir_transform ${MLIRLibs})
 add_dependencies(onnx_mlir_transform gen_krnl_ops)
+
+add_subdirectory(onnx)


@@ -0,0 +1,7 @@
+add_library(onnx_mlir_attribute_promotion
+        attribute_promotion.cpp)
+target_include_directories(onnx_mlir_attribute_promotion
+        PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
+        ${ONNX_MLIR_SRC_ROOT})
+target_link_libraries(onnx_mlir_attribute_promotion
+        onnx_mlir_promotable_const_operands_op_interface)


@@ -0,0 +1,92 @@
+//===----- attribute_promotion.cpp - Attribute Promotion -----------------===//
+//
+// Copyright 2020 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file implements a function-level pass to move an operand to become
+// an attribute if doing so is desirable and legal.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+#include "src/interface/promotable_const_operands_op_interface.hpp"
+#include "src/pass/passes.hpp"
+
+using namespace mlir;
+
+namespace {
+
+/*!
+ * Helper function to create a NoneType constant value if `none` is empty.
+ */
+void getOrCreateNoneValue(llvm::Optional<mlir::Value> &none, FuncOp f) {
+  if (none.hasValue())
+    return;
+
+  OpBuilder builder(f.getContext());
+  builder.setInsertionPointToStart(&f.front());
+  none = builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr());
+}
+
+/*!
+ * FunctionPass that performs attribute promotion by iterating over a list of
+ * candidate operations and moving constant operands to attributes whenever
+ * desirable (as instructed by the PromotableConstOperandsOpInterface).
+ */
+class AttributePromotionPass
+    : public mlir::FunctionPass<AttributePromotionPass> {
+public:
+  void runOnFunction() override {
+    auto f = getFunction();
+
+    // A function-scope shared none value used to indicate a missing operand.
+    llvm::Optional<mlir::Value> none;
+
+    // Iterate over the operations that may need attribute promotion.
+    f.walk([&](mlir::Operation *op) {
+      if (PromotableConstOperandsOpInterface opWithConstOperand =
+              dyn_cast<PromotableConstOperandsOpInterface>(op)) {
+        auto promotableOperands = opWithConstOperand.promotableConstOperands();
+        for (const auto &operandNameToIdx : promotableOperands) {
+          auto name = operandNameToIdx.first;
+          auto i = operandNameToIdx.second;
+
+          // If the i-th operand is defined by a constant operation, then
+          // move it to an attribute, and use None to indicate the absence
+          // of the original operand value.
+          auto operandToPromote = op->getOperand(i);
+          if (auto constantOp = dyn_cast_or_null<ConstantOp>(
+                  operandToPromote.getDefiningOp())) {
+            op->setAttr(name, constantOp.value());
+            getOrCreateNoneValue(none, f);
+            op->setOperand(i, *none);
+          }
+        }
+      }
+    });
+
+    // Dispatch canonicalization pattern rewriters to eliminate redundant
+    // constant operations.
+    OwningRewritePatternList patterns;
+    auto *context = &getContext();
+    ConstantOp::getCanonicalizationPatterns(patterns, context);
+    applyPatternsGreedily(f, patterns);
+  }
+};
+
+} // end anonymous namespace
+
+/*!
+ * Create an Attribute Promotion pass.
+ */
+std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() {
+  return std::make_unique<AttributePromotionPass>();
+}
+
+static PassRegistration<AttributePromotionPass> pass(
+    "attribute-promotion", "Promote constant operands to attributes.");


@@ -0,0 +1,32 @@
+// RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s
+
+func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %shape = constant dense<[6, 7, 42]> : tensor<3xi32>
+  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+  // CHECK-LABEL: test_should_promote_to_attribute
+  // CHECK-NEXT: [[NONE:%.+]] = constant unit
+  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
+  // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
+}
+
+func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> {
+  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+  // CHECK-LABEL: test_should_not_promote_to_attribute
+  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, %{{.*}}) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
+  // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
+}
+
+func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %shape = constant dense<[6, 7, 42]> : tensor<3xi32>
+  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
+  %1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32>
+  "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_promote_to_attribute_without_removing_const_op
+  // CHECK-NEXT: [[NONE:%.+]] = constant unit
+  // CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32>
+  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
+  // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
+  // CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
+}