Support attribute promotion. (#34)
* Support attribute promotion. * Simplify op interface name. * 1. Add more comments to Attribute Promotion Pass. 2. Move Promotable Const Operand Interface to src/interface, and link against it. * Complete NFC change onnx -> onnx-mlir. * Move attribute_promotion pass to src/transform. * Nit: reword comment. * Support Attribute Promotion in gen_doc.py. * Add test. * Update ONNX doc. * Add negative test. * Rename onnxop.inc -> onnx_ops.td.inc. * Include onnx_ops.td.inc. * Nit: better comments. * Prettify CMake. * Remove original attribute_promotion code, improve comments. * Append '_op_interface' to op interface decl/defs. * Namespace cmake targets using onnx_mlir_ prefix. * Use updated header name. * Use new body file name. * Fix dependency. * Use new CMake target name. * Make attribute promotion self-contained by removing redundant constant operaions inside the pass execution. * Remove canonicalization pass. * Increase comments. * Use stricter checks. * Add one more test case. * Remove %arg1 as it's never used.
This commit is contained in:
parent
2814ea3898
commit
549af8f0b2
|
@ -3749,7 +3749,7 @@ ONNX Reshape operation
|
||||||
#### Operands:
|
#### Operands:
|
||||||
|
|
||||||
1. `data`: memref of any type values or tensor of any type values
|
1. `data`: memref of any type values or tensor of any type values
|
||||||
1. `shape`: memref of any type values or tensor of any type values
|
1. `shape`: memref of any type values or tensor of any type values or none type
|
||||||
|
|
||||||
#### Attributes:
|
#### Attributes:
|
||||||
|
|
||||||
|
|
|
@ -47,13 +47,20 @@ OpsWithShapeInference = [
|
||||||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||||
'Sign', 'Constant', 'ONNXAveragePoolOp', 'Abs'
|
'Sign', 'Constant', 'AveragePool', 'Abs'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
OpsWithCanonicalizer = [
|
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm']
|
||||||
'Add', 'Identity', 'Gemm'
|
|
||||||
]
|
# Operations who have operands that, if produced by constant operations, should
|
||||||
|
# be promoted to become an attribute (via attribute promotion).
|
||||||
|
#
|
||||||
|
# For each operation, a key/value pair is used to specify how attribute promotion
|
||||||
|
# should proceed. The key is the operation's name and the value is a list of
|
||||||
|
# tuples, whose first item is the attribute/operand name, and the second item is
|
||||||
|
# the index at which such operand occurs in the list of the operation's inputs.
|
||||||
|
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}
|
||||||
|
|
||||||
# Add an Op in this list if the Op needs result type deduction which is required
|
# Add an Op in this list if the Op needs result type deduction which is required
|
||||||
# when writing declarative rewriting rules. Deduced type is always
|
# when writing declarative rewriting rules. Deduced type is always
|
||||||
|
@ -224,7 +231,7 @@ def get_operands_or_results(schema, is_input):
|
||||||
return "AnyTypeOf<[{}]>".format(", ".join(types))
|
return "AnyTypeOf<[{}]>".format(", ".join(types))
|
||||||
|
|
||||||
name_to_types = OrderedDict()
|
name_to_types = OrderedDict()
|
||||||
for value in value_list:
|
for i, value in enumerate(value_list):
|
||||||
elem_types = get_allowed_elem_types(schema, value)
|
elem_types = get_allowed_elem_types(schema, value)
|
||||||
|
|
||||||
if elem_types is None:
|
if elem_types is None:
|
||||||
|
@ -233,6 +240,13 @@ def get_operands_or_results(schema, is_input):
|
||||||
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
|
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
|
||||||
types = list(map(lambda x: x.format(elem_types), types))
|
types = list(map(lambda x: x.format(elem_types), types))
|
||||||
|
|
||||||
|
# If operand is promotable to an attribute, then it must be
|
||||||
|
# nullable in case it migrates to be an attribute.
|
||||||
|
if schema.name in OpsWithPromotableConstOperands:
|
||||||
|
idxs = dict(OpsWithPromotableConstOperands[schema.name]).values()
|
||||||
|
if i in idxs:
|
||||||
|
types.append("NoneType")
|
||||||
|
|
||||||
if OpSchema.FormalParameterOption.Optional == value.option:
|
if OpSchema.FormalParameterOption.Optional == value.option:
|
||||||
types.append("NoneType")
|
types.append("NoneType")
|
||||||
elif OpSchema.FormalParameterOption.Variadic == value.option:
|
elif OpSchema.FormalParameterOption.Variadic == value.option:
|
||||||
|
@ -313,6 +327,25 @@ def get_attrs(schema):
|
||||||
return name_to_type
|
return name_to_type
|
||||||
|
|
||||||
|
|
||||||
|
def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
|
||||||
|
cpp_name_to_idx_literal = "{" + ", ".join([
|
||||||
|
"{{\"{}\", {}}}".format(*name_to_idx)
|
||||||
|
for name_to_idx in const_operands_name_to_idx
|
||||||
|
]) + "}"
|
||||||
|
|
||||||
|
s += indent + "let extraClassDeclaration = [{\n"
|
||||||
|
indent = inc_indent(indent)
|
||||||
|
s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n"
|
||||||
|
indent = inc_indent(indent)
|
||||||
|
s += indent + "return {};\n".format(cpp_name_to_idx_literal)
|
||||||
|
indent = dec_indent(indent)
|
||||||
|
s += indent + "}\n"
|
||||||
|
indent = dec_indent(indent)
|
||||||
|
s += indent + "}];\n"
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
def gen_op_def(schema):
|
def gen_op_def(schema):
|
||||||
indent = inc_indent()
|
indent = inc_indent()
|
||||||
s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name)
|
s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name)
|
||||||
|
@ -321,6 +354,8 @@ def gen_op_def(schema):
|
||||||
traits = ["NoSideEffect"]
|
traits = ["NoSideEffect"]
|
||||||
if schema.name in OpsWithShapeInference:
|
if schema.name in OpsWithShapeInference:
|
||||||
traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
|
traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
|
||||||
|
if schema.name in OpsWithPromotableConstOperands.keys():
|
||||||
|
traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">")
|
||||||
s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))
|
s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))
|
||||||
|
|
||||||
# Generate decl for canonicalizer.
|
# Generate decl for canonicalizer.
|
||||||
|
@ -400,6 +435,9 @@ def gen_op_def(schema):
|
||||||
|
|
||||||
s += '\n' + indent + '];\n'
|
s += '\n' + indent + '];\n'
|
||||||
|
|
||||||
|
if schema.name in OpsWithPromotableConstOperands:
|
||||||
|
s = get_promotable_const_operands_func(
|
||||||
|
s, indent, OpsWithPromotableConstOperands[schema.name])
|
||||||
s += '}\n\n'
|
s += '}\n\n'
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
@ -506,7 +544,7 @@ if __name__ == '__main__':
|
||||||
curr_dir = os.path.dirname(os.path.realpath(__file__))
|
curr_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
class Args(object):
|
class Args(object):
|
||||||
op_def_file = os.path.join(curr_dir, 'onnxop.inc')
|
op_def_file = os.path.join(curr_dir, 'onnx_ops.td.inc')
|
||||||
op_importer_file = os.path.join(curr_dir, 'op_build_table.inc')
|
op_importer_file = os.path.join(curr_dir, 'op_build_table.inc')
|
||||||
|
|
||||||
main(Args)
|
main(Args)
|
||||||
|
|
|
@ -8,7 +8,6 @@ add_library(compiler
|
||||||
dialect/krnl/krnl_helper.cpp
|
dialect/krnl/krnl_helper.cpp
|
||||||
dialect/krnl/krnl_helper.hpp
|
dialect/krnl/krnl_helper.hpp
|
||||||
pass/shape_inference_interface.hpp
|
pass/shape_inference_interface.hpp
|
||||||
dialect/onnx/onnxop.inc
|
|
||||||
pass/onnx_combine.cpp
|
pass/onnx_combine.cpp
|
||||||
pass/onnx_rewrite.cpp
|
pass/onnx_rewrite.cpp
|
||||||
pass/onnx_decompose.cpp
|
pass/onnx_decompose.cpp
|
||||||
|
@ -47,12 +46,22 @@ onnx_mlir_tablegen(onnx_rewrite.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(gen_onnx_rewrite)
|
add_public_tablegen_target(gen_onnx_rewrite)
|
||||||
add_dependencies(compiler gen_onnx_rewrite)
|
add_dependencies(compiler gen_onnx_rewrite)
|
||||||
|
|
||||||
|
add_subdirectory(interface)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||||
onnx_mlir_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnx_mlir_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
onnx_mlir_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnx_mlir_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
|
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
|
||||||
add_public_tablegen_target(gen_onnx)
|
add_public_tablegen_target(gen_onnx)
|
||||||
|
|
||||||
|
add_dependencies(gen_onnx gen_shape_inference)
|
||||||
add_dependencies(compiler gen_onnx)
|
add_dependencies(compiler gen_onnx)
|
||||||
|
|
||||||
|
# TODO: onnx_mlir_gen_promotable_const_operands_op_interface is really a
|
||||||
|
# dependency of the onnx dialect library, which is currently part of `compiler`.
|
||||||
|
add_dependencies(compiler onnx_mlir_gen_promotable_const_operands_op_interface)
|
||||||
|
|
||||||
|
|
||||||
add_onnx_mlir_dialect_doc(onnx dialect/onnx/onnx.td)
|
add_onnx_mlir_dialect_doc(onnx dialect/onnx/onnx.td)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
||||||
|
@ -66,14 +75,14 @@ target_include_directories(onnx_mlir_onnx_decompose
|
||||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNX_MLIR_SRC_ROOT})
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
target_link_libraries(onnx_mlir_onnx_decompose ${MLIRLibs})
|
target_link_libraries(onnx_mlir_onnx_decompose ${MLIRLibs})
|
||||||
add_dependencies(onnx_mlir_onnx_decompose gen_krnl_ops)
|
add_dependencies(onnx_mlir_onnx_decompose gen_onnx)
|
||||||
|
|
||||||
add_library(onnx_mlir_shape_inference pass/shape_inference_pass.cpp)
|
add_library(onnx_mlir_shape_inference pass/shape_inference_pass.cpp)
|
||||||
target_include_directories(onnx_mlir_shape_inference
|
target_include_directories(onnx_mlir_shape_inference
|
||||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNX_MLIR_SRC_ROOT})
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
target_link_libraries(onnx_mlir_shape_inference ${MLIRLibs})
|
target_link_libraries(onnx_mlir_shape_inference ${MLIRLibs})
|
||||||
add_dependencies(onnx_mlir_shape_inference gen_krnl_ops)
|
add_dependencies(onnx_mlir_shape_inference gen_onnx)
|
||||||
|
|
||||||
add_library(onnx_mlir_lower_frontend
|
add_library(onnx_mlir_lower_frontend
|
||||||
conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
|
conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
|
||||||
|
@ -106,8 +115,17 @@ add_subdirectory(runtime)
|
||||||
|
|
||||||
add_executable(onnx-mlir main.cpp)
|
add_executable(onnx-mlir main.cpp)
|
||||||
|
|
||||||
target_link_libraries(onnx-mlir builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_onnx_decompose onnx_mlir_shape_inference onnx_mlir_lower_frontend)
|
target_link_libraries(onnx-mlir
|
||||||
whole_archive_link_mlir(onnx-mlir ${MLIRWholeArchiveLibs})
|
builder
|
||||||
|
${MLIRLibs}
|
||||||
|
onnx_mlir_transform
|
||||||
|
onnx_mlir_onnx_decompose
|
||||||
|
onnx_mlir_shape_inference
|
||||||
|
onnx_mlir_lower_frontend
|
||||||
|
onnx_mlir_attribute_promotion)
|
||||||
|
whole_archive_link_mlir(onnx-mlir
|
||||||
|
${MLIRWholeArchiveLibs})
|
||||||
|
|
||||||
find_package(ZLIB REQUIRED)
|
find_package(ZLIB REQUIRED)
|
||||||
target_link_libraries(onnx-mlir ${ZLIB_LIBRARIES})
|
target_link_libraries(onnx-mlir ${ZLIB_LIBRARIES})
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,11 @@ include "mlir/IR/OpBase.td"
|
||||||
include "pass/shape_inference_interface.td"
|
include "pass/shape_inference_interface.td"
|
||||||
#endif // SHAPE_INFERENCE_INTERFACE
|
#endif // SHAPE_INFERENCE_INTERFACE
|
||||||
|
|
||||||
|
#ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
#else
|
||||||
|
include "interface/promotable_const_operands_op_interface.td"
|
||||||
|
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
|
||||||
def ONNX_Dialect : Dialect {
|
def ONNX_Dialect : Dialect {
|
||||||
let name = "onnx";
|
let name = "onnx";
|
||||||
let cppNamespace = "";
|
let cppNamespace = "";
|
||||||
|
@ -48,7 +53,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
// CC=gcc CXX=g++ pip install -e .
|
// CC=gcc CXX=g++ pip install -e .
|
||||||
//run the script
|
//run the script
|
||||||
// python onnx/defs/gen_doc.py
|
// python onnx/defs/gen_doc.py
|
||||||
//result is in docs/onnxop.inc
|
//result is in docs/onnx_ops.td.inc
|
||||||
//current limitations:
|
//current limitations:
|
||||||
// 1. Attributes are not processed
|
// 1. Attributes are not processed
|
||||||
// 2. output type inference not implemented except Add
|
// 2. output type inference not implemented except Add
|
||||||
|
@ -56,7 +61,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
// 4. type of string, complex64 and complex128 for input/output are ignored
|
// 4. type of string, complex64 and complex128 for input/output are ignored
|
||||||
// 5. unsigned int are treated as signed one
|
// 5. unsigned int are treated as signed one
|
||||||
|
|
||||||
include "dialect/onnx/onnxop.inc"
|
include "dialect/onnx/onnx_ops.td.inc"
|
||||||
|
|
||||||
// Indicate entry point functions of ONNX graph.
|
// Indicate entry point functions of ONNX graph.
|
||||||
def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
||||||
|
|
|
@ -10,6 +10,9 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
@ -17,6 +20,7 @@
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
#include "src/pass/shape_inference_interface.hpp"
|
#include "src/pass/shape_inference_interface.hpp"
|
||||||
|
#include "src/interface/promotable_const_operands_op_interface.hpp"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
//********************************************************
|
//********************************************************
|
||||||
// This file is generated on UTC-02/24/2020, 06:44:13.
|
// This file is generated on UTC-03/18/2020, 03:36:58.
|
||||||
// Do not modify this file directly.
|
// Do not modify this file directly.
|
||||||
// This file is automatically generated via script.
|
// This file is automatically generated via script.
|
||||||
// Details can be found in doc/readonnxdefs.md .
|
// Details can be found in doc/readonnxdefs.md .
|
||||||
|
@ -2501,7 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
|
||||||
let summary = "ONNX Reshape operation";
|
let summary = "ONNX Reshape operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Reshape the input tensor similar to numpy.reshape."
|
"Reshape the input tensor similar to numpy.reshape."
|
||||||
|
@ -2512,8 +2512,13 @@ def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||||
"from the input tensor)."
|
"from the input tensor)."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape);
|
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
std::map<std::string, size_t> promotableConstOperands() {
|
||||||
|
return {{"shape", 1}};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXResizeOp:ONNX_Op<"Resize",
|
def ONNXResizeOp:ONNX_Op<"Resize",
|
|
@ -0,0 +1,14 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS promotable_const_operands_op_interface.td)
|
||||||
|
onnx_mlir_tablegen(promotable_const_operands_op_interface.hpp.inc -gen-op-interface-decls)
|
||||||
|
onnx_mlir_tablegen(promotable_const_operands_op_interface.cpp.inc -gen-op-interface-defs)
|
||||||
|
add_public_tablegen_target(onnx_mlir_gen_promotable_const_operands_op_interface)
|
||||||
|
|
||||||
|
add_library(onnx_mlir_promotable_const_operands_op_interface
|
||||||
|
promotable_const_operands_op_interface.hpp
|
||||||
|
promotable_const_operands_op_interface.cpp)
|
||||||
|
target_include_directories(onnx_mlir_promotable_const_operands_op_interface
|
||||||
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
|
|
||||||
|
add_dependencies(onnx_mlir_promotable_const_operands_op_interface
|
||||||
|
onnx_mlir_gen_promotable_const_operands_op_interface)
|
|
@ -0,0 +1,23 @@
|
||||||
|
//===------------ promotable_const_operands_op_interface.cpp --------------===//
|
||||||
|
//===-------- Promotable Const Operands Op Interface Definition -----------===//
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains the definition of the promotable const operands op
|
||||||
|
// interface.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
|
#include "src/interface/promotable_const_operands_op_interface.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Promotable Const Operands Op Interface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "src/interface/promotable_const_operands_op_interface.cpp.inc"
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
//===------------ promotable_const_operands_op_interface.cpp --------------===//
|
||||||
|
//===-------- Promotable Const Operands Op Interface Definition -----------===//
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains the declaration of the promotable const operands op
|
||||||
|
// interface.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
/// Include the auto-generated declarations.
|
||||||
|
#include "src/interface/promotable_const_operands_op_interface.hpp.inc"
|
||||||
|
|
||||||
|
} // end namespace mlir
|
|
@ -0,0 +1,34 @@
|
||||||
|
//===------------- promotable_const_operands_op_interface.td --------------===//
|
||||||
|
//===---- Promotable Const Operands Op Interface TableGen Definition ------===//
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains the TableGen definition of the promotable const operands
|
||||||
|
// op interface.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
#else
|
||||||
|
#define PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
|
||||||
|
#ifdef OP_BASE
|
||||||
|
#else
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def PromotableConstOperandsOpInterface : OpInterface<"PromotableConstOperandsOpInterface"> {
|
||||||
|
let description = [{
|
||||||
|
Interface to access a registered method to infer the return types for an
|
||||||
|
operation that can be used during type inference.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||||
|
"std::map<std::string, size_t>", "promotableConstOperands">
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
@ -126,6 +126,7 @@ int main(int argc, char *argv[]) {
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
pm.addPass(mlir::createAttributePromotionPass());
|
||||||
|
|
||||||
if (emissionTarget >= EmitMLIR) {
|
if (emissionTarget >= EmitMLIR) {
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
|
|
@ -20,6 +20,9 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
|
||||||
|
|
||||||
std::unique_ptr<Pass> createShapeInferencePass();
|
std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
|
|
||||||
|
/// Pass for promoting constant operands to attributes.
|
||||||
|
std::unique_ptr<Pass> createAttributePromotionPass();
|
||||||
|
|
||||||
/// Add pass for lowering to Krnl IR.
|
/// Add pass for lowering to Krnl IR.
|
||||||
std::unique_ptr<Pass> createLowerToKrnlPass();
|
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
#include "shape_inference_interface.hpp"
|
#include "shape_inference_interface.hpp"
|
||||||
#include "src/dialect/onnx/onnx_ops.hpp"
|
|
||||||
|
|
||||||
#include "passes.hpp"
|
#include "passes.hpp"
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,19 @@ add_dependencies(onnx-mlir-opt gen_krnl_ops)
|
||||||
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT})
|
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT})
|
||||||
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT})
|
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT})
|
||||||
|
|
||||||
target_link_libraries(onnx-mlir-opt builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_shape_inference onnx_mlir_lower_frontend curses)
|
target_link_libraries(onnx-mlir-opt
|
||||||
whole_archive_link_mlir(onnx-mlir-opt ${MLIRWholeArchiveLibs})
|
builder
|
||||||
whole_archive_link_onnx_mlir(onnx-mlir-opt compiler onnx_mlir_transform onnx_mlir_lower_frontend onnx_mlir_shape_inference)
|
${MLIRLibs}
|
||||||
|
onnx_mlir_transform
|
||||||
|
onnx_mlir_shape_inference
|
||||||
|
onnx_mlir_lower_frontend
|
||||||
|
onnx_mlir_promotable_const_operands_op_interface
|
||||||
|
curses)
|
||||||
|
whole_archive_link_mlir(onnx-mlir-opt
|
||||||
|
${MLIRWholeArchiveLibs})
|
||||||
|
whole_archive_link_onnx_mlir(onnx-mlir-opt
|
||||||
|
compiler
|
||||||
|
onnx_mlir_transform
|
||||||
|
onnx_mlir_lower_frontend
|
||||||
|
onnx_mlir_shape_inference
|
||||||
|
onnx_mlir_attribute_promotion)
|
|
@ -7,3 +7,5 @@ target_include_directories(onnx_mlir_transform
|
||||||
${ONNX_MLIR_SRC_ROOT})
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
target_link_libraries(onnx_mlir_transform ${MLIRLibs})
|
target_link_libraries(onnx_mlir_transform ${MLIRLibs})
|
||||||
add_dependencies(onnx_mlir_transform gen_krnl_ops)
|
add_dependencies(onnx_mlir_transform gen_krnl_ops)
|
||||||
|
|
||||||
|
add_subdirectory(onnx)
|
|
@ -0,0 +1,7 @@
|
||||||
|
add_library(onnx_mlir_attribute_promotion
|
||||||
|
attribute_promotion.cpp)
|
||||||
|
target_include_directories(onnx_mlir_attribute_promotion
|
||||||
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
|
${ONNF_MLIR_SRC_ROOT})
|
||||||
|
target_link_libraries(onnx_mlir_attribute_promotion
|
||||||
|
onnx_mlir_promotable_const_operands_op_interface)
|
|
@ -0,0 +1,92 @@
|
||||||
|
//===----- attribute_promotion.cpp - Attribute Promotion
|
||||||
|
//-------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements a function level pass to move an operand to become
|
||||||
|
// an attribute if desirable and legal.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include "src/interface/promotable_const_operands_op_interface.hpp"
|
||||||
|
#include "src/pass/passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Helper function to create a NoneTyped constant value if `none` is empty.
|
||||||
|
*/
|
||||||
|
void getOrCreateNoneValue(llvm::Optional<mlir::Value> &none, FuncOp f) {
|
||||||
|
if (none.hasValue())
|
||||||
|
return;
|
||||||
|
|
||||||
|
OpBuilder builder(f.getContext());
|
||||||
|
builder.setInsertionPointToStart(&f.front());
|
||||||
|
none = builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr());
|
||||||
|
}
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* FunctionPass that performs attribute promotion by iterating over a list of
|
||||||
|
* candidate operations and moves constant operands to attributes whenever
|
||||||
|
* desirable (as instructed by the PromotableConstOperandsOpInterface).
|
||||||
|
*/
|
||||||
|
class AttributePromotionPass
|
||||||
|
: public mlir::FunctionPass<AttributePromotionPass> {
|
||||||
|
public:
|
||||||
|
void runOnFunction() override {
|
||||||
|
auto f = getFunction();
|
||||||
|
|
||||||
|
// A function-scope shared none value used to indicate an missing operand.
|
||||||
|
llvm::Optional<mlir::Value> none;
|
||||||
|
|
||||||
|
// Iterate on the operations that may need attribute promotion.
|
||||||
|
f.walk([&](mlir::Operation *op) {
|
||||||
|
if (PromotableConstOperandsOpInterface opWithConstOperand =
|
||||||
|
dyn_cast<PromotableConstOperandsOpInterface>(op)) {
|
||||||
|
auto promotableOperands = opWithConstOperand.promotableConstOperands();
|
||||||
|
for (const auto &operandNameToIdx : promotableOperands) {
|
||||||
|
auto name = operandNameToIdx.first;
|
||||||
|
auto i = operandNameToIdx.second;
|
||||||
|
|
||||||
|
// If the i-th operand is defined by an constant operation, then
|
||||||
|
// move it to an attribute, and use None to indicate the absence
|
||||||
|
// of the original operand value.
|
||||||
|
auto operandToPromote = op->getOperand(i);
|
||||||
|
if (auto constantOp = dyn_cast_or_null<ConstantOp>(
|
||||||
|
operandToPromote.getDefiningOp())) {
|
||||||
|
op->setAttr(name, constantOp.value());
|
||||||
|
getOrCreateNoneValue(none, f);
|
||||||
|
op->setOperand(i, *none);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Dispatch canonicalization pattern rewriters to eliminate redundant
|
||||||
|
// constant operaions.
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
auto *context = &getContext();
|
||||||
|
ConstantOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
applyPatternsGreedily(f, patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Create a Attribute Promotion pass.
|
||||||
|
*/
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() {
|
||||||
|
return std::make_unique<AttributePromotionPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<AttributePromotionPass> pass(
|
||||||
|
"attribute-promotion", "Promote constant operands to attributes.");
|
|
@ -0,0 +1,32 @@
|
||||||
|
// RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%shape = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||||
|
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
// CHECK-LABEL: test_should_promote_to_attribute
|
||||||
|
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||||
|
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
// CHECK-LABEL: test_should_not_promote_to_attribute
|
||||||
|
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, %{{.*}}) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||||
|
%shape = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||||
|
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
%1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
|
||||||
|
// CHECK-LABEL: test_promote_to_attribute_without_removing_const_op
|
||||||
|
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||||
|
// CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||||
|
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
|
||||||
|
}
|
Loading…
Reference in New Issue