Support attribute promotion. (#34)
* Support attribute promotion. * Simplify op interface name. * 1. Add more comments to Attribute Promotion Pass. 2. Move Promotable Const Operand Interface to src/interface, and link against it. * Complete NFC change onnx -> onnx-mlir. * Move attribute_promotion pass to src/transform. * Nit: reword comment. * Support Attribute Promotion in gen_doc.py. * Add test. * Update ONNX doc. * Add negative test. * Rename onnxop.inc -> onnx_ops.td.inc. * Include onnx_ops.td.inc. * Nit: better comments. * Prettify CMake. * Remove original attribute_promotion code, improve comments. * Append '_op_interface' to op interface decl/defs. * Namespace cmake targets using onnx_mlir_ prefix. * Use updated header name. * Use new body file name. * Fix dependency. * Use new CMake target name. * Make attribute promotion self-contained by removing redundant constant operaions inside the pass execution. * Remove canonicalization pass. * Increase comments. * Use stricter checks. * Add one more test case. * Remove %arg1 as it's never used.
This commit is contained in:
		
							parent
							
								
									2814ea3898
								
							
						
					
					
						commit
						549af8f0b2
					
				|  | @ -3749,7 +3749,7 @@ ONNX Reshape operation | ||||||
| #### Operands: | #### Operands: | ||||||
| 
 | 
 | ||||||
| 1. `data`: memref of any type values or tensor of any type values | 1. `data`: memref of any type values or tensor of any type values | ||||||
| 1. `shape`: memref of any type values or tensor of any type values | 1. `shape`: memref of any type values or tensor of any type values or none type | ||||||
| 
 | 
 | ||||||
| #### Attributes: | #### Attributes: | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -47,13 +47,20 @@ OpsWithShapeInference = [ | ||||||
|     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', |     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', | ||||||
|     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', |     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', | ||||||
|     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', |     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', | ||||||
|     'Sign', 'Constant', 'ONNXAveragePoolOp', 'Abs' |     'Sign', 'Constant', 'AveragePool', 'Abs' | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| # Operations supporting canonicalization. | # Operations supporting canonicalization. | ||||||
| OpsWithCanonicalizer = [ | OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm'] | ||||||
|     'Add', 'Identity', 'Gemm' | 
 | ||||||
| ] | # Operations who have operands that, if produced by constant operations, should | ||||||
|  | # be promoted to become an attribute (via attribute promotion). | ||||||
|  | # | ||||||
|  | # For each operation, a key/value pair is used to specify how attribute promotion | ||||||
|  | # should proceed. The key is the operation's name and the value is a list of | ||||||
|  | # tuples, whose first item is the attribute/operand name, and the second item is | ||||||
|  | # the index at which such operand occurs in the list of the operation's inputs. | ||||||
|  | OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]} | ||||||
| 
 | 
 | ||||||
| # Add an Op in this list if the Op needs result type deduction which is required | # Add an Op in this list if the Op needs result type deduction which is required | ||||||
| # when writing declarative rewriting rules. Deduced type is always | # when writing declarative rewriting rules. Deduced type is always | ||||||
|  | @ -224,7 +231,7 @@ def get_operands_or_results(schema, is_input): | ||||||
|             return "AnyTypeOf<[{}]>".format(", ".join(types)) |             return "AnyTypeOf<[{}]>".format(", ".join(types)) | ||||||
| 
 | 
 | ||||||
|     name_to_types = OrderedDict() |     name_to_types = OrderedDict() | ||||||
|     for value in value_list: |     for i, value in enumerate(value_list): | ||||||
|         elem_types = get_allowed_elem_types(schema, value) |         elem_types = get_allowed_elem_types(schema, value) | ||||||
| 
 | 
 | ||||||
|         if elem_types is None: |         if elem_types is None: | ||||||
|  | @ -233,6 +240,13 @@ def get_operands_or_results(schema, is_input): | ||||||
|             types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] |             types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] | ||||||
|             types = list(map(lambda x: x.format(elem_types), types)) |             types = list(map(lambda x: x.format(elem_types), types)) | ||||||
| 
 | 
 | ||||||
|  |         # If operand is promotable to an attribute, then it must be | ||||||
|  |         # nullable in case it migrates to be an attribute. | ||||||
|  |         if schema.name in OpsWithPromotableConstOperands: | ||||||
|  |             idxs = dict(OpsWithPromotableConstOperands[schema.name]).values() | ||||||
|  |             if i in idxs: | ||||||
|  |                 types.append("NoneType") | ||||||
|  | 
 | ||||||
|         if OpSchema.FormalParameterOption.Optional == value.option: |         if OpSchema.FormalParameterOption.Optional == value.option: | ||||||
|             types.append("NoneType") |             types.append("NoneType") | ||||||
|         elif OpSchema.FormalParameterOption.Variadic == value.option: |         elif OpSchema.FormalParameterOption.Variadic == value.option: | ||||||
|  | @ -313,6 +327,25 @@ def get_attrs(schema): | ||||||
|     return name_to_type |     return name_to_type | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): | ||||||
|  |     cpp_name_to_idx_literal = "{" + ", ".join([ | ||||||
|  |         "{{\"{}\", {}}}".format(*name_to_idx) | ||||||
|  |         for name_to_idx in const_operands_name_to_idx | ||||||
|  |     ]) + "}" | ||||||
|  | 
 | ||||||
|  |     s += indent + "let extraClassDeclaration = [{\n" | ||||||
|  |     indent = inc_indent(indent) | ||||||
|  |     s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n" | ||||||
|  |     indent = inc_indent(indent) | ||||||
|  |     s += indent + "return {};\n".format(cpp_name_to_idx_literal) | ||||||
|  |     indent = dec_indent(indent) | ||||||
|  |     s += indent + "}\n" | ||||||
|  |     indent = dec_indent(indent) | ||||||
|  |     s += indent + "}];\n" | ||||||
|  | 
 | ||||||
|  |     return s | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def gen_op_def(schema): | def gen_op_def(schema): | ||||||
|     indent = inc_indent() |     indent = inc_indent() | ||||||
|     s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name) |     s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name) | ||||||
|  | @ -321,6 +354,8 @@ def gen_op_def(schema): | ||||||
|     traits = ["NoSideEffect"] |     traits = ["NoSideEffect"] | ||||||
|     if schema.name in OpsWithShapeInference: |     if schema.name in OpsWithShapeInference: | ||||||
|         traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>") |         traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>") | ||||||
|  |     if schema.name in OpsWithPromotableConstOperands.keys(): | ||||||
|  |         traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">") | ||||||
|     s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits)) |     s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits)) | ||||||
| 
 | 
 | ||||||
|     # Generate decl for canonicalizer. |     # Generate decl for canonicalizer. | ||||||
|  | @ -400,6 +435,9 @@ def gen_op_def(schema): | ||||||
| 
 | 
 | ||||||
|             s += '\n' + indent + '];\n' |             s += '\n' + indent + '];\n' | ||||||
| 
 | 
 | ||||||
|  |     if schema.name in OpsWithPromotableConstOperands: | ||||||
|  |         s = get_promotable_const_operands_func( | ||||||
|  |             s, indent, OpsWithPromotableConstOperands[schema.name]) | ||||||
|     s += '}\n\n' |     s += '}\n\n' | ||||||
|     return s |     return s | ||||||
| 
 | 
 | ||||||
|  | @ -506,7 +544,7 @@ if __name__ == '__main__': | ||||||
|     curr_dir = os.path.dirname(os.path.realpath(__file__)) |     curr_dir = os.path.dirname(os.path.realpath(__file__)) | ||||||
| 
 | 
 | ||||||
|     class Args(object): |     class Args(object): | ||||||
|         op_def_file = os.path.join(curr_dir, 'onnxop.inc') |         op_def_file = os.path.join(curr_dir, 'onnx_ops.td.inc') | ||||||
|         op_importer_file = os.path.join(curr_dir, 'op_build_table.inc') |         op_importer_file = os.path.join(curr_dir, 'op_build_table.inc') | ||||||
| 
 | 
 | ||||||
|     main(Args) |     main(Args) | ||||||
|  |  | ||||||
|  | @ -8,7 +8,6 @@ add_library(compiler | ||||||
|         dialect/krnl/krnl_helper.cpp |         dialect/krnl/krnl_helper.cpp | ||||||
|         dialect/krnl/krnl_helper.hpp |         dialect/krnl/krnl_helper.hpp | ||||||
|         pass/shape_inference_interface.hpp |         pass/shape_inference_interface.hpp | ||||||
|         dialect/onnx/onnxop.inc |  | ||||||
|         pass/onnx_combine.cpp |         pass/onnx_combine.cpp | ||||||
|         pass/onnx_rewrite.cpp |         pass/onnx_rewrite.cpp | ||||||
|         pass/onnx_decompose.cpp |         pass/onnx_decompose.cpp | ||||||
|  | @ -47,12 +46,22 @@ onnx_mlir_tablegen(onnx_rewrite.inc -gen-rewriters) | ||||||
| add_public_tablegen_target(gen_onnx_rewrite) | add_public_tablegen_target(gen_onnx_rewrite) | ||||||
| add_dependencies(compiler gen_onnx_rewrite) | add_dependencies(compiler gen_onnx_rewrite) | ||||||
| 
 | 
 | ||||||
|  | add_subdirectory(interface) | ||||||
|  | 
 | ||||||
| set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) | set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) | ||||||
| onnx_mlir_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") | onnx_mlir_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") | ||||||
| onnx_mlir_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") | onnx_mlir_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") | ||||||
| set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md) | set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md) | ||||||
| add_public_tablegen_target(gen_onnx) | add_public_tablegen_target(gen_onnx) | ||||||
|  | 
 | ||||||
|  | add_dependencies(gen_onnx gen_shape_inference) | ||||||
| add_dependencies(compiler gen_onnx) | add_dependencies(compiler gen_onnx) | ||||||
|  | 
 | ||||||
|  | # TODO: onnx_mlir_gen_promotable_const_operands_op_interface is really a | ||||||
|  | # dependency of the onnx dialect library, which is currently part of `compiler`. | ||||||
|  | add_dependencies(compiler onnx_mlir_gen_promotable_const_operands_op_interface) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| add_onnx_mlir_dialect_doc(onnx dialect/onnx/onnx.td) | add_onnx_mlir_dialect_doc(onnx dialect/onnx/onnx.td) | ||||||
| 
 | 
 | ||||||
| set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td) | set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td) | ||||||
|  | @ -66,14 +75,14 @@ target_include_directories(onnx_mlir_onnx_decompose | ||||||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} |         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||||
|         ${ONNX_MLIR_SRC_ROOT}) |         ${ONNX_MLIR_SRC_ROOT}) | ||||||
| target_link_libraries(onnx_mlir_onnx_decompose ${MLIRLibs}) | target_link_libraries(onnx_mlir_onnx_decompose ${MLIRLibs}) | ||||||
| add_dependencies(onnx_mlir_onnx_decompose gen_krnl_ops) | add_dependencies(onnx_mlir_onnx_decompose gen_onnx) | ||||||
| 
 | 
 | ||||||
| add_library(onnx_mlir_shape_inference pass/shape_inference_pass.cpp) | add_library(onnx_mlir_shape_inference pass/shape_inference_pass.cpp) | ||||||
| target_include_directories(onnx_mlir_shape_inference | target_include_directories(onnx_mlir_shape_inference | ||||||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} |         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||||
|         ${ONNX_MLIR_SRC_ROOT}) |         ${ONNX_MLIR_SRC_ROOT}) | ||||||
| target_link_libraries(onnx_mlir_shape_inference ${MLIRLibs}) | target_link_libraries(onnx_mlir_shape_inference ${MLIRLibs}) | ||||||
| add_dependencies(onnx_mlir_shape_inference gen_krnl_ops) | add_dependencies(onnx_mlir_shape_inference gen_onnx) | ||||||
| 
 | 
 | ||||||
| add_library(onnx_mlir_lower_frontend | add_library(onnx_mlir_lower_frontend | ||||||
|         conversion/onnx_to_krnl/onnx_to_krnl_common.cpp |         conversion/onnx_to_krnl/onnx_to_krnl_common.cpp | ||||||
|  | @ -106,8 +115,17 @@ add_subdirectory(runtime) | ||||||
| 
 | 
 | ||||||
| add_executable(onnx-mlir main.cpp) | add_executable(onnx-mlir main.cpp) | ||||||
| 
 | 
 | ||||||
| target_link_libraries(onnx-mlir builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_onnx_decompose onnx_mlir_shape_inference onnx_mlir_lower_frontend) | target_link_libraries(onnx-mlir | ||||||
| whole_archive_link_mlir(onnx-mlir ${MLIRWholeArchiveLibs}) |         builder | ||||||
|  |         ${MLIRLibs} | ||||||
|  |         onnx_mlir_transform | ||||||
|  |         onnx_mlir_onnx_decompose | ||||||
|  |         onnx_mlir_shape_inference | ||||||
|  |         onnx_mlir_lower_frontend | ||||||
|  |         onnx_mlir_attribute_promotion) | ||||||
|  | whole_archive_link_mlir(onnx-mlir | ||||||
|  |         ${MLIRWholeArchiveLibs}) | ||||||
|  | 
 | ||||||
| find_package(ZLIB REQUIRED) | find_package(ZLIB REQUIRED) | ||||||
| target_link_libraries(onnx-mlir ${ZLIB_LIBRARIES}) | target_link_libraries(onnx-mlir ${ZLIB_LIBRARIES}) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -22,6 +22,11 @@ include "mlir/IR/OpBase.td" | ||||||
| include "pass/shape_inference_interface.td" | include "pass/shape_inference_interface.td" | ||||||
| #endif // SHAPE_INFERENCE_INTERFACE | #endif // SHAPE_INFERENCE_INTERFACE | ||||||
| 
 | 
 | ||||||
|  | #ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||||
|  | #else | ||||||
|  | include "interface/promotable_const_operands_op_interface.td" | ||||||
|  | #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||||
|  | 
 | ||||||
| def ONNX_Dialect : Dialect { | def ONNX_Dialect : Dialect { | ||||||
|   let name = "onnx"; |   let name = "onnx"; | ||||||
|   let cppNamespace = ""; |   let cppNamespace = ""; | ||||||
|  | @ -48,7 +53,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> : | ||||||
| //  CC=gcc CXX=g++ pip install -e . | //  CC=gcc CXX=g++ pip install -e . | ||||||
| //run the script | //run the script | ||||||
| //  python onnx/defs/gen_doc.py | //  python onnx/defs/gen_doc.py | ||||||
| //result is in docs/onnxop.inc | //result is in docs/onnx_ops.td.inc | ||||||
| //current limitations: | //current limitations: | ||||||
| // 1. Attributes are not processed | // 1. Attributes are not processed | ||||||
| // 2. output type inference not implemented except Add | // 2. output type inference not implemented except Add | ||||||
|  | @ -56,7 +61,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> : | ||||||
| // 4. type of string, complex64 and complex128 for input/output are ignored  | // 4. type of string, complex64 and complex128 for input/output are ignored  | ||||||
| // 5. unsigned int are treated as signed one | // 5. unsigned int are treated as signed one | ||||||
| 
 | 
 | ||||||
| include "dialect/onnx/onnxop.inc" | include "dialect/onnx/onnx_ops.td.inc" | ||||||
| 
 | 
 | ||||||
| // Indicate entry point functions of ONNX graph. | // Indicate entry point functions of ONNX graph. | ||||||
| def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { | def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { | ||||||
|  |  | ||||||
|  | @ -10,6 +10,9 @@ | ||||||
| 
 | 
 | ||||||
| #pragma once | #pragma once | ||||||
| 
 | 
 | ||||||
|  | #include <map> | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
| #include "mlir/Dialect/StandardOps/Ops.h" | #include "mlir/Dialect/StandardOps/Ops.h" | ||||||
| #include "mlir/IR/Builders.h" | #include "mlir/IR/Builders.h" | ||||||
| #include "mlir/IR/Dialect.h" | #include "mlir/IR/Dialect.h" | ||||||
|  | @ -17,6 +20,7 @@ | ||||||
| #include "mlir/IR/StandardTypes.h" | #include "mlir/IR/StandardTypes.h" | ||||||
| 
 | 
 | ||||||
| #include "src/pass/shape_inference_interface.hpp" | #include "src/pass/shape_inference_interface.hpp" | ||||||
|  | #include "src/interface/promotable_const_operands_op_interface.hpp" | ||||||
| 
 | 
 | ||||||
| namespace mlir { | namespace mlir { | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,5 +1,5 @@ | ||||||
| //********************************************************
 | //********************************************************
 | ||||||
| //   This file is generated on UTC-02/24/2020, 06:44:13.
 | //   This file is generated on UTC-03/18/2020, 03:36:58.
 | ||||||
| //   Do not modify this file directly.
 | //   Do not modify this file directly.
 | ||||||
| //   This file is automatically generated via script.
 | //   This file is automatically generated via script.
 | ||||||
| //   Details can be found in doc/readonnxdefs.md .
 | //   Details can be found in doc/readonnxdefs.md .
 | ||||||
|  | @ -2501,7 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXReshapeOp:ONNX_Op<"Reshape", | def ONNXReshapeOp:ONNX_Op<"Reshape", | ||||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { |   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> { | ||||||
|   let summary = "ONNX Reshape operation"; |   let summary = "ONNX Reshape operation"; | ||||||
|   let description = [{ |   let description = [{ | ||||||
|   "Reshape the input tensor similar to numpy.reshape." |   "Reshape the input tensor similar to numpy.reshape." | ||||||
|  | @ -2512,8 +2512,13 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", | ||||||
|   "from the input tensor)." |   "from the input tensor)." | ||||||
|   }]; |   }]; | ||||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, |   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, | ||||||
|     AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); |     AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape); | ||||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped); |   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped); | ||||||
|  |   let extraClassDeclaration = [{ | ||||||
|  |     std::map<std::string, size_t> promotableConstOperands() { | ||||||
|  |       return {{"shape", 1}}; | ||||||
|  |     } | ||||||
|  |   }]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXResizeOp:ONNX_Op<"Resize", | def ONNXResizeOp:ONNX_Op<"Resize", | ||||||
|  | @ -0,0 +1,14 @@ | ||||||
|  | set(LLVM_TARGET_DEFINITIONS promotable_const_operands_op_interface.td) | ||||||
|  | onnx_mlir_tablegen(promotable_const_operands_op_interface.hpp.inc -gen-op-interface-decls) | ||||||
|  | onnx_mlir_tablegen(promotable_const_operands_op_interface.cpp.inc -gen-op-interface-defs) | ||||||
|  | add_public_tablegen_target(onnx_mlir_gen_promotable_const_operands_op_interface) | ||||||
|  | 
 | ||||||
|  | add_library(onnx_mlir_promotable_const_operands_op_interface | ||||||
|  |         promotable_const_operands_op_interface.hpp | ||||||
|  |         promotable_const_operands_op_interface.cpp) | ||||||
|  | target_include_directories(onnx_mlir_promotable_const_operands_op_interface | ||||||
|  |         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||||
|  |         ${ONNX_MLIR_SRC_ROOT}) | ||||||
|  | 
 | ||||||
|  | add_dependencies(onnx_mlir_promotable_const_operands_op_interface | ||||||
|  |         onnx_mlir_gen_promotable_const_operands_op_interface) | ||||||
|  | @ -0,0 +1,23 @@ | ||||||
|  | //===------------ promotable_const_operands_op_interface.cpp --------------===//
 | ||||||
|  | //===-------- Promotable Const Operands Op Interface Definition -----------===//
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2020 The IBM Research Authors.
 | ||||||
|  | //
 | ||||||
|  | // =============================================================================
 | ||||||
|  | //
 | ||||||
|  | // This file contains the definition of the promotable const operands op
 | ||||||
|  | // interface.
 | ||||||
|  | //
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | #include "src/interface/promotable_const_operands_op_interface.hpp" | ||||||
|  | 
 | ||||||
|  | using namespace mlir; | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Promotable Const Operands Op Interface
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | #include "src/interface/promotable_const_operands_op_interface.cpp.inc" | ||||||
|  | 
 | ||||||
|  | @ -0,0 +1,25 @@ | ||||||
|  | //===------------ promotable_const_operands_op_interface.cpp --------------===//
 | ||||||
|  | //===-------- Promotable Const Operands Op Interface Definition -----------===//
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2020 The IBM Research Authors.
 | ||||||
|  | //
 | ||||||
|  | // =============================================================================
 | ||||||
|  | //
 | ||||||
|  | // This file contains the declaration of the promotable const operands op
 | ||||||
|  | // interface.
 | ||||||
|  | //
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | #pragma once | ||||||
|  | 
 | ||||||
|  | #include <map> | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
|  | #include "mlir/IR/OpDefinition.h" | ||||||
|  | 
 | ||||||
|  | namespace mlir { | ||||||
|  | 
 | ||||||
|  | /// Include the auto-generated declarations.
 | ||||||
|  | #include "src/interface/promotable_const_operands_op_interface.hpp.inc" | ||||||
|  | 
 | ||||||
|  | }  // end namespace mlir
 | ||||||
|  | @ -0,0 +1,34 @@ | ||||||
|  | //===------------- promotable_const_operands_op_interface.td --------------===// | ||||||
|  | //===---- Promotable Const Operands Op Interface TableGen Definition ------===// | ||||||
|  | // | ||||||
|  | // Copyright 2020 The IBM Research Authors. | ||||||
|  | // | ||||||
|  | // ============================================================================= | ||||||
|  | // | ||||||
|  | // This file contains the TableGen definition of the promotable const operands | ||||||
|  | // op interface. | ||||||
|  | // | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | 
 | ||||||
|  | #ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||||
|  | #else | ||||||
|  | #define PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||||
|  | 
 | ||||||
|  | #ifdef OP_BASE | ||||||
|  | #else | ||||||
|  | include "mlir/IR/OpBase.td" | ||||||
|  | #endif // OP_BASE | ||||||
|  | 
 | ||||||
|  | def PromotableConstOperandsOpInterface : OpInterface<"PromotableConstOperandsOpInterface"> { | ||||||
|  |   let description = [{ | ||||||
|  |     Interface to access a registered method to infer the return types for an | ||||||
|  |     operation that can be used during type inference. | ||||||
|  |   }]; | ||||||
|  | 
 | ||||||
|  |   let methods = [ | ||||||
|  |     InterfaceMethod<"Infer and set the output shape for the current operation.", | ||||||
|  |                     "std::map<std::string, size_t>", "promotableConstOperands"> | ||||||
|  |   ]; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||||
|  | @ -126,6 +126,7 @@ int main(int argc, char *argv[]) { | ||||||
|   pm.addPass(mlir::createShapeInferencePass()); |   pm.addPass(mlir::createShapeInferencePass()); | ||||||
|   pm.addPass(mlir::createCanonicalizerPass()); |   pm.addPass(mlir::createCanonicalizerPass()); | ||||||
|   pm.addPass(mlir::createShapeInferencePass()); |   pm.addPass(mlir::createShapeInferencePass()); | ||||||
|  |   pm.addPass(mlir::createAttributePromotionPass()); | ||||||
| 
 | 
 | ||||||
|   if (emissionTarget >= EmitMLIR) { |   if (emissionTarget >= EmitMLIR) { | ||||||
|     pm.addPass(mlir::createLowerToKrnlPass()); |     pm.addPass(mlir::createLowerToKrnlPass()); | ||||||
|  |  | ||||||
|  | @ -20,6 +20,9 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass(); | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<Pass> createShapeInferencePass(); | std::unique_ptr<Pass> createShapeInferencePass(); | ||||||
| 
 | 
 | ||||||
|  | /// Pass for promoting constant operands to attributes.
 | ||||||
|  | std::unique_ptr<Pass> createAttributePromotionPass(); | ||||||
|  | 
 | ||||||
| /// Add pass for lowering to Krnl IR.
 | /// Add pass for lowering to Krnl IR.
 | ||||||
| std::unique_ptr<Pass> createLowerToKrnlPass(); | std::unique_ptr<Pass> createLowerToKrnlPass(); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -12,9 +12,9 @@ | ||||||
| #include "mlir/Pass/Pass.h" | #include "mlir/Pass/Pass.h" | ||||||
| #include "llvm/ADT/SmallPtrSet.h" | #include "llvm/ADT/SmallPtrSet.h" | ||||||
| #include "llvm/Support/raw_ostream.h" | #include "llvm/Support/raw_ostream.h" | ||||||
|  | #include "mlir/IR/StandardTypes.h" | ||||||
| 
 | 
 | ||||||
| #include "shape_inference_interface.hpp" | #include "shape_inference_interface.hpp" | ||||||
| #include "src/dialect/onnx/onnx_ops.hpp" |  | ||||||
| 
 | 
 | ||||||
| #include "passes.hpp" | #include "passes.hpp" | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -4,6 +4,19 @@ add_dependencies(onnx-mlir-opt gen_krnl_ops) | ||||||
| target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT}) | target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT}) | ||||||
| target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT}) | target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT}) | ||||||
| 
 | 
 | ||||||
| target_link_libraries(onnx-mlir-opt builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_shape_inference onnx_mlir_lower_frontend curses) | target_link_libraries(onnx-mlir-opt | ||||||
| whole_archive_link_mlir(onnx-mlir-opt ${MLIRWholeArchiveLibs}) |         builder | ||||||
| whole_archive_link_onnx_mlir(onnx-mlir-opt compiler onnx_mlir_transform onnx_mlir_lower_frontend onnx_mlir_shape_inference) |         ${MLIRLibs} | ||||||
|  |         onnx_mlir_transform | ||||||
|  |         onnx_mlir_shape_inference | ||||||
|  |         onnx_mlir_lower_frontend | ||||||
|  |         onnx_mlir_promotable_const_operands_op_interface | ||||||
|  |         curses) | ||||||
|  | whole_archive_link_mlir(onnx-mlir-opt | ||||||
|  |         ${MLIRWholeArchiveLibs}) | ||||||
|  | whole_archive_link_onnx_mlir(onnx-mlir-opt | ||||||
|  |         compiler | ||||||
|  |         onnx_mlir_transform | ||||||
|  |         onnx_mlir_lower_frontend | ||||||
|  |         onnx_mlir_shape_inference | ||||||
|  |         onnx_mlir_attribute_promotion) | ||||||
|  | @ -7,3 +7,5 @@ target_include_directories(onnx_mlir_transform | ||||||
|                                    ${ONNX_MLIR_SRC_ROOT}) |                                    ${ONNX_MLIR_SRC_ROOT}) | ||||||
| target_link_libraries(onnx_mlir_transform ${MLIRLibs}) | target_link_libraries(onnx_mlir_transform ${MLIRLibs}) | ||||||
| add_dependencies(onnx_mlir_transform gen_krnl_ops) | add_dependencies(onnx_mlir_transform gen_krnl_ops) | ||||||
|  | 
 | ||||||
|  | add_subdirectory(onnx) | ||||||
|  | @ -0,0 +1,7 @@ | ||||||
|  | add_library(onnx_mlir_attribute_promotion | ||||||
|  |         attribute_promotion.cpp) | ||||||
|  | target_include_directories(onnx_mlir_attribute_promotion | ||||||
|  |         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||||
|  |         ${ONNF_MLIR_SRC_ROOT}) | ||||||
|  | target_link_libraries(onnx_mlir_attribute_promotion | ||||||
|  |         onnx_mlir_promotable_const_operands_op_interface) | ||||||
|  | @ -0,0 +1,92 @@ | ||||||
|  | //===----- attribute_promotion.cpp - Attribute Promotion
 | ||||||
|  | //-------------------===//
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2020 The IBM Research Authors.
 | ||||||
|  | //
 | ||||||
|  | // =============================================================================
 | ||||||
|  | //
 | ||||||
|  | // This file implements a function level pass to move an operand to become
 | ||||||
|  | // an attribute if desirable and legal.
 | ||||||
|  | //
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | #include "mlir/Dialect/StandardOps/Ops.h" | ||||||
|  | #include "mlir/IR/Builders.h" | ||||||
|  | #include "mlir/IR/PatternMatch.h" | ||||||
|  | #include "mlir/Pass/Pass.h" | ||||||
|  | 
 | ||||||
|  | #include "src/interface/promotable_const_operands_op_interface.hpp" | ||||||
|  | #include "src/pass/passes.hpp" | ||||||
|  | 
 | ||||||
|  | using namespace mlir; | ||||||
|  | 
 | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | /*!
 | ||||||
|  |  * Helper function to create a NoneTyped constant value if `none` is empty. | ||||||
|  |  */ | ||||||
|  | void getOrCreateNoneValue(llvm::Optional<mlir::Value> &none, FuncOp f) { | ||||||
|  |   if (none.hasValue()) | ||||||
|  |     return; | ||||||
|  | 
 | ||||||
|  |   OpBuilder builder(f.getContext()); | ||||||
|  |   builder.setInsertionPointToStart(&f.front()); | ||||||
|  |   none = builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr()); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /*!
 | ||||||
|  |  *  FunctionPass that performs attribute promotion by iterating over a list of | ||||||
|  |  *  candidate operations and moves constant operands to attributes whenever | ||||||
|  |  *  desirable (as instructed by the PromotableConstOperandsOpInterface). | ||||||
|  |  */ | ||||||
|  | class AttributePromotionPass | ||||||
|  |     : public mlir::FunctionPass<AttributePromotionPass> { | ||||||
|  | public: | ||||||
|  |   void runOnFunction() override { | ||||||
|  |     auto f = getFunction(); | ||||||
|  | 
 | ||||||
|  |     // A function-scope shared none value used to indicate an missing operand.
 | ||||||
|  |     llvm::Optional<mlir::Value> none; | ||||||
|  | 
 | ||||||
|  |     // Iterate on the operations that may need attribute promotion.
 | ||||||
|  |     f.walk([&](mlir::Operation *op) { | ||||||
|  |       if (PromotableConstOperandsOpInterface opWithConstOperand = | ||||||
|  |               dyn_cast<PromotableConstOperandsOpInterface>(op)) { | ||||||
|  |         auto promotableOperands = opWithConstOperand.promotableConstOperands(); | ||||||
|  |         for (const auto &operandNameToIdx : promotableOperands) { | ||||||
|  |           auto name = operandNameToIdx.first; | ||||||
|  |           auto i = operandNameToIdx.second; | ||||||
|  | 
 | ||||||
|  |           // If the i-th operand is defined by an constant operation, then
 | ||||||
|  |           // move it to an attribute, and use None to indicate the absence
 | ||||||
|  |           // of the original operand value.
 | ||||||
|  |           auto operandToPromote = op->getOperand(i); | ||||||
|  |           if (auto constantOp = dyn_cast_or_null<ConstantOp>( | ||||||
|  |                   operandToPromote.getDefiningOp())) { | ||||||
|  |             op->setAttr(name, constantOp.value()); | ||||||
|  |             getOrCreateNoneValue(none, f); | ||||||
|  |             op->setOperand(i, *none); | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     // Dispatch canonicalization pattern rewriters to eliminate redundant
 | ||||||
|  |     // constant operaions.
 | ||||||
|  |     OwningRewritePatternList patterns; | ||||||
|  |     auto *context = &getContext(); | ||||||
|  |     ConstantOp::getCanonicalizationPatterns(patterns, context); | ||||||
|  |     applyPatternsGreedily(f, patterns); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | } // end anonymous namespace
 | ||||||
|  | 
 | ||||||
|  | /*!
 | ||||||
|  |  * Create a Attribute Promotion pass. | ||||||
|  |  */ | ||||||
|  | std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() { | ||||||
|  |   return std::make_unique<AttributePromotionPass>(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static PassRegistration<AttributePromotionPass> pass( | ||||||
|  |     "attribute-promotion", "Promote constant operands to attributes."); | ||||||
|  | @ -0,0 +1,32 @@ | ||||||
|  | // RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s | ||||||
|  | 
 | ||||||
|  | func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { | ||||||
|  |   %shape = constant dense<[6, 7, 42]> : tensor<3xi32> | ||||||
|  |   %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32> | ||||||
|  |   return %0 : tensor<*xf32> | ||||||
|  |   // CHECK-LABEL: test_should_promote_to_attribute | ||||||
|  |   // CHECK-NEXT: [[NONE:%.+]] = constant unit | ||||||
|  |   // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32> | ||||||
|  |   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32> | ||||||
|  |   return %0 : tensor<*xf32> | ||||||
|  |   // CHECK-LABEL: test_should_not_promote_to_attribute | ||||||
|  |   // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, %{{.*}}) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32> | ||||||
|  |   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) { | ||||||
|  |   %shape = constant dense<[6, 7, 42]> : tensor<3xi32> | ||||||
|  |   %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32> | ||||||
|  |   %1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> () | ||||||
|  |   // CHECK-LABEL: test_promote_to_attribute_without_removing_const_op | ||||||
|  |   // CHECK-NEXT: [[NONE:%.+]] = constant unit | ||||||
|  |   // CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32> | ||||||
|  |   // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32> | ||||||
|  |   // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32> | ||||||
|  |   // CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32> | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue