Support attribute promotion. (#34)
* Support attribute promotion. * Simplify op interface name. * 1. Add more comments to Attribute Promotion Pass. 2. Move Promotable Const Operand Interface to src/interface, and link against it. * Complete NFC change onnx -> onnx-mlir. * Move attribute_promotion pass to src/transform. * Nit: reword comment. * Support Attribute Promotion in gen_doc.py. * Add test. * Update ONNX doc. * Add negative test. * Rename onnxop.inc -> onnx_ops.td.inc. * Include onnx_ops.td.inc. * Nit: better comments. * Prettify CMake. * Remove original attribute_promotion code, improve comments. * Append '_op_interface' to op interface decl/defs. * Namespace cmake targets using onnx_mlir_ prefix. * Use updated header name. * Use new body file name. * Fix dependency. * Use new CMake target name. * Make attribute promotion self-contained by removing redundant constant operaions inside the pass execution. * Remove canonicalization pass. * Increase comments. * Use stricter checks. * Add one more test case. * Remove %arg1 as it's never used.
This commit is contained in:
		
							parent
							
								
									2814ea3898
								
							
						
					
					
						commit
						549af8f0b2
					
				|  | @ -3749,7 +3749,7 @@ ONNX Reshape operation | |||
| #### Operands: | ||||
| 
 | ||||
| 1. `data`: memref of any type values or tensor of any type values | ||||
| 1. `shape`: memref of any type values or tensor of any type values | ||||
| 1. `shape`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  |  | |||
|  | @ -47,13 +47,20 @@ OpsWithShapeInference = [ | |||
|     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', | ||||
|     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', | ||||
|     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', | ||||
|     'Sign', 'Constant', 'ONNXAveragePoolOp', 'Abs' | ||||
|     'Sign', 'Constant', 'AveragePool', 'Abs' | ||||
| ] | ||||
| 
 | ||||
| # Operations supporting canonicalization. | ||||
| OpsWithCanonicalizer = [ | ||||
|     'Add', 'Identity', 'Gemm' | ||||
| ] | ||||
| OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm'] | ||||
| 
 | ||||
| # Operations who have operands that, if produced by constant operations, should | ||||
| # be promoted to become an attribute (via attribute promotion). | ||||
| # | ||||
| # For each operation, a key/value pair is used to specify how attribute promotion | ||||
| # should proceed. The key is the operation's name and the value is a list of | ||||
| # tuples, whose first item is the attribute/operand name, and the second item is | ||||
| # the index at which such operand occurs in the list of the operation's inputs. | ||||
| OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]} | ||||
| 
 | ||||
| # Add an Op in this list if the Op needs result type deduction which is required | ||||
| # when writing declarative rewriting rules. Deduced type is always | ||||
|  | @ -224,7 +231,7 @@ def get_operands_or_results(schema, is_input): | |||
|             return "AnyTypeOf<[{}]>".format(", ".join(types)) | ||||
| 
 | ||||
|     name_to_types = OrderedDict() | ||||
|     for value in value_list: | ||||
|     for i, value in enumerate(value_list): | ||||
|         elem_types = get_allowed_elem_types(schema, value) | ||||
| 
 | ||||
|         if elem_types is None: | ||||
|  | @ -233,6 +240,13 @@ def get_operands_or_results(schema, is_input): | |||
|             types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] | ||||
|             types = list(map(lambda x: x.format(elem_types), types)) | ||||
| 
 | ||||
|         # If operand is promotable to an attribute, then it must be | ||||
|         # nullable in case it migrates to be an attribute. | ||||
|         if schema.name in OpsWithPromotableConstOperands: | ||||
|             idxs = dict(OpsWithPromotableConstOperands[schema.name]).values() | ||||
|             if i in idxs: | ||||
|                 types.append("NoneType") | ||||
| 
 | ||||
|         if OpSchema.FormalParameterOption.Optional == value.option: | ||||
|             types.append("NoneType") | ||||
|         elif OpSchema.FormalParameterOption.Variadic == value.option: | ||||
|  | @ -313,6 +327,25 @@ def get_attrs(schema): | |||
|     return name_to_type | ||||
| 
 | ||||
| 
 | ||||
| def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): | ||||
|     cpp_name_to_idx_literal = "{" + ", ".join([ | ||||
|         "{{\"{}\", {}}}".format(*name_to_idx) | ||||
|         for name_to_idx in const_operands_name_to_idx | ||||
|     ]) + "}" | ||||
| 
 | ||||
|     s += indent + "let extraClassDeclaration = [{\n" | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n" | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "return {};\n".format(cpp_name_to_idx_literal) | ||||
|     indent = dec_indent(indent) | ||||
|     s += indent + "}\n" | ||||
|     indent = dec_indent(indent) | ||||
|     s += indent + "}];\n" | ||||
| 
 | ||||
|     return s | ||||
| 
 | ||||
| 
 | ||||
| def gen_op_def(schema): | ||||
|     indent = inc_indent() | ||||
|     s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name) | ||||
|  | @ -321,6 +354,8 @@ def gen_op_def(schema): | |||
|     traits = ["NoSideEffect"] | ||||
|     if schema.name in OpsWithShapeInference: | ||||
|         traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>") | ||||
|     if schema.name in OpsWithPromotableConstOperands.keys(): | ||||
|         traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">") | ||||
|     s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits)) | ||||
| 
 | ||||
|     # Generate decl for canonicalizer. | ||||
|  | @ -400,6 +435,9 @@ def gen_op_def(schema): | |||
| 
 | ||||
|             s += '\n' + indent + '];\n' | ||||
| 
 | ||||
|     if schema.name in OpsWithPromotableConstOperands: | ||||
|         s = get_promotable_const_operands_func( | ||||
|             s, indent, OpsWithPromotableConstOperands[schema.name]) | ||||
|     s += '}\n\n' | ||||
|     return s | ||||
| 
 | ||||
|  | @ -506,7 +544,7 @@ if __name__ == '__main__': | |||
|     curr_dir = os.path.dirname(os.path.realpath(__file__)) | ||||
| 
 | ||||
|     class Args(object): | ||||
|         op_def_file = os.path.join(curr_dir, 'onnxop.inc') | ||||
|         op_def_file = os.path.join(curr_dir, 'onnx_ops.td.inc') | ||||
|         op_importer_file = os.path.join(curr_dir, 'op_build_table.inc') | ||||
| 
 | ||||
|     main(Args) | ||||
|  |  | |||
|  | @ -8,7 +8,6 @@ add_library(compiler | |||
|         dialect/krnl/krnl_helper.cpp | ||||
|         dialect/krnl/krnl_helper.hpp | ||||
|         pass/shape_inference_interface.hpp | ||||
|         dialect/onnx/onnxop.inc | ||||
|         pass/onnx_combine.cpp | ||||
|         pass/onnx_rewrite.cpp | ||||
|         pass/onnx_decompose.cpp | ||||
|  | @ -47,12 +46,22 @@ onnx_mlir_tablegen(onnx_rewrite.inc -gen-rewriters) | |||
| add_public_tablegen_target(gen_onnx_rewrite) | ||||
| add_dependencies(compiler gen_onnx_rewrite) | ||||
| 
 | ||||
| add_subdirectory(interface) | ||||
| 
 | ||||
| set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) | ||||
| onnx_mlir_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") | ||||
| onnx_mlir_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") | ||||
| set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md) | ||||
| add_public_tablegen_target(gen_onnx) | ||||
| 
 | ||||
| add_dependencies(gen_onnx gen_shape_inference) | ||||
| add_dependencies(compiler gen_onnx) | ||||
| 
 | ||||
| # TODO: onnx_mlir_gen_promotable_const_operands_op_interface is really a | ||||
| # dependency of the onnx dialect library, which is currently part of `compiler`. | ||||
| add_dependencies(compiler onnx_mlir_gen_promotable_const_operands_op_interface) | ||||
| 
 | ||||
| 
 | ||||
| add_onnx_mlir_dialect_doc(onnx dialect/onnx/onnx.td) | ||||
| 
 | ||||
| set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td) | ||||
|  | @ -66,14 +75,14 @@ target_include_directories(onnx_mlir_onnx_decompose | |||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||
|         ${ONNX_MLIR_SRC_ROOT}) | ||||
| target_link_libraries(onnx_mlir_onnx_decompose ${MLIRLibs}) | ||||
| add_dependencies(onnx_mlir_onnx_decompose gen_krnl_ops) | ||||
| add_dependencies(onnx_mlir_onnx_decompose gen_onnx) | ||||
| 
 | ||||
| add_library(onnx_mlir_shape_inference pass/shape_inference_pass.cpp) | ||||
| target_include_directories(onnx_mlir_shape_inference | ||||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||
|         ${ONNX_MLIR_SRC_ROOT}) | ||||
| target_link_libraries(onnx_mlir_shape_inference ${MLIRLibs}) | ||||
| add_dependencies(onnx_mlir_shape_inference gen_krnl_ops) | ||||
| add_dependencies(onnx_mlir_shape_inference gen_onnx) | ||||
| 
 | ||||
| add_library(onnx_mlir_lower_frontend | ||||
|         conversion/onnx_to_krnl/onnx_to_krnl_common.cpp | ||||
|  | @ -106,8 +115,17 @@ add_subdirectory(runtime) | |||
| 
 | ||||
| add_executable(onnx-mlir main.cpp) | ||||
| 
 | ||||
| target_link_libraries(onnx-mlir builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_onnx_decompose onnx_mlir_shape_inference onnx_mlir_lower_frontend) | ||||
| whole_archive_link_mlir(onnx-mlir ${MLIRWholeArchiveLibs}) | ||||
| target_link_libraries(onnx-mlir | ||||
|         builder | ||||
|         ${MLIRLibs} | ||||
|         onnx_mlir_transform | ||||
|         onnx_mlir_onnx_decompose | ||||
|         onnx_mlir_shape_inference | ||||
|         onnx_mlir_lower_frontend | ||||
|         onnx_mlir_attribute_promotion) | ||||
| whole_archive_link_mlir(onnx-mlir | ||||
|         ${MLIRWholeArchiveLibs}) | ||||
| 
 | ||||
| find_package(ZLIB REQUIRED) | ||||
| target_link_libraries(onnx-mlir ${ZLIB_LIBRARIES}) | ||||
| 
 | ||||
|  |  | |||
|  | @ -22,6 +22,11 @@ include "mlir/IR/OpBase.td" | |||
| include "pass/shape_inference_interface.td" | ||||
| #endif // SHAPE_INFERENCE_INTERFACE | ||||
| 
 | ||||
| #ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||
| #else | ||||
| include "interface/promotable_const_operands_op_interface.td" | ||||
| #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||
| 
 | ||||
| def ONNX_Dialect : Dialect { | ||||
|   let name = "onnx"; | ||||
|   let cppNamespace = ""; | ||||
|  | @ -48,7 +53,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> : | |||
| //  CC=gcc CXX=g++ pip install -e . | ||||
| //run the script | ||||
| //  python onnx/defs/gen_doc.py | ||||
| //result is in docs/onnxop.inc | ||||
| //result is in docs/onnx_ops.td.inc | ||||
| //current limitations: | ||||
| // 1. Attributes are not processed | ||||
| // 2. output type inference not implemented except Add | ||||
|  | @ -56,7 +61,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> : | |||
| // 4. type of string, complex64 and complex128 for input/output are ignored  | ||||
| // 5. unsigned int are treated as signed one | ||||
| 
 | ||||
| include "dialect/onnx/onnxop.inc" | ||||
| include "dialect/onnx/onnx_ops.td.inc" | ||||
| 
 | ||||
| // Indicate entry point functions of ONNX graph. | ||||
| def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { | ||||
|  |  | |||
|  | @ -10,6 +10,9 @@ | |||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <map> | ||||
| #include <string> | ||||
| 
 | ||||
| #include "mlir/Dialect/StandardOps/Ops.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/Dialect.h" | ||||
|  | @ -17,6 +20,7 @@ | |||
| #include "mlir/IR/StandardTypes.h" | ||||
| 
 | ||||
| #include "src/pass/shape_inference_interface.hpp" | ||||
| #include "src/interface/promotable_const_operands_op_interface.hpp" | ||||
| 
 | ||||
| namespace mlir { | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| //********************************************************
 | ||||
| //   This file is generated on UTC-02/24/2020, 06:44:13.
 | ||||
| //   This file is generated on UTC-03/18/2020, 03:36:58.
 | ||||
| //   Do not modify this file directly.
 | ||||
| //   This file is automatically generated via script.
 | ||||
| //   Details can be found in doc/readonnxdefs.md .
 | ||||
|  | @ -2501,7 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu", | |||
| } | ||||
| 
 | ||||
| def ONNXReshapeOp:ONNX_Op<"Reshape", | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> { | ||||
|   let summary = "ONNX Reshape operation"; | ||||
|   let description = [{ | ||||
|   "Reshape the input tensor similar to numpy.reshape." | ||||
|  | @ -2512,8 +2512,13 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", | |||
|   "from the input tensor)." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, | ||||
|     AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); | ||||
|     AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped); | ||||
|   let extraClassDeclaration = [{ | ||||
|     std::map<std::string, size_t> promotableConstOperands() { | ||||
|       return {{"shape", 1}}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def ONNXResizeOp:ONNX_Op<"Resize", | ||||
|  | @ -0,0 +1,14 @@ | |||
| set(LLVM_TARGET_DEFINITIONS promotable_const_operands_op_interface.td) | ||||
| onnx_mlir_tablegen(promotable_const_operands_op_interface.hpp.inc -gen-op-interface-decls) | ||||
| onnx_mlir_tablegen(promotable_const_operands_op_interface.cpp.inc -gen-op-interface-defs) | ||||
| add_public_tablegen_target(onnx_mlir_gen_promotable_const_operands_op_interface) | ||||
| 
 | ||||
| add_library(onnx_mlir_promotable_const_operands_op_interface | ||||
|         promotable_const_operands_op_interface.hpp | ||||
|         promotable_const_operands_op_interface.cpp) | ||||
| target_include_directories(onnx_mlir_promotable_const_operands_op_interface | ||||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||
|         ${ONNX_MLIR_SRC_ROOT}) | ||||
| 
 | ||||
| add_dependencies(onnx_mlir_promotable_const_operands_op_interface | ||||
|         onnx_mlir_gen_promotable_const_operands_op_interface) | ||||
|  | @ -0,0 +1,23 @@ | |||
| //===------------ promotable_const_operands_op_interface.cpp --------------===//
 | ||||
| //===-------- Promotable Const Operands Op Interface Definition -----------===//
 | ||||
| //
 | ||||
| // Copyright 2020 The IBM Research Authors.
 | ||||
| //
 | ||||
| // =============================================================================
 | ||||
| //
 | ||||
| // This file contains the definition of the promotable const operands op
 | ||||
| // interface.
 | ||||
| //
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| 
 | ||||
| #include "src/interface/promotable_const_operands_op_interface.hpp" | ||||
| 
 | ||||
| using namespace mlir; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Promotable Const Operands Op Interface
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #include "src/interface/promotable_const_operands_op_interface.cpp.inc" | ||||
| 
 | ||||
|  | @ -0,0 +1,25 @@ | |||
| //===------------ promotable_const_operands_op_interface.cpp --------------===//
 | ||||
| //===-------- Promotable Const Operands Op Interface Definition -----------===//
 | ||||
| //
 | ||||
| // Copyright 2020 The IBM Research Authors.
 | ||||
| //
 | ||||
| // =============================================================================
 | ||||
| //
 | ||||
| // This file contains the declaration of the promotable const operands op
 | ||||
| // interface.
 | ||||
| //
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <map> | ||||
| #include <string> | ||||
| 
 | ||||
| #include "mlir/IR/OpDefinition.h" | ||||
| 
 | ||||
| namespace mlir { | ||||
| 
 | ||||
| /// Include the auto-generated declarations.
 | ||||
| #include "src/interface/promotable_const_operands_op_interface.hpp.inc" | ||||
| 
 | ||||
| }  // end namespace mlir
 | ||||
|  | @ -0,0 +1,34 @@ | |||
| //===------------- promotable_const_operands_op_interface.td --------------===// | ||||
| //===---- Promotable Const Operands Op Interface TableGen Definition ------===// | ||||
| // | ||||
| // Copyright 2020 The IBM Research Authors. | ||||
| // | ||||
| // ============================================================================= | ||||
| // | ||||
| // This file contains the TableGen definition of the promotable const operands | ||||
| // op interface. | ||||
| // | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| #ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||
| #else | ||||
| #define PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||
| 
 | ||||
| #ifdef OP_BASE | ||||
| #else | ||||
| include "mlir/IR/OpBase.td" | ||||
| #endif // OP_BASE | ||||
| 
 | ||||
| def PromotableConstOperandsOpInterface : OpInterface<"PromotableConstOperandsOpInterface"> { | ||||
|   let description = [{ | ||||
|     Interface to access a registered method to infer the return types for an | ||||
|     operation that can be used during type inference. | ||||
|   }]; | ||||
| 
 | ||||
|   let methods = [ | ||||
|     InterfaceMethod<"Infer and set the output shape for the current operation.", | ||||
|                     "std::map<std::string, size_t>", "promotableConstOperands"> | ||||
|   ]; | ||||
| } | ||||
| 
 | ||||
| #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||
|  | @ -126,6 +126,7 @@ int main(int argc, char *argv[]) { | |||
|   pm.addPass(mlir::createShapeInferencePass()); | ||||
|   pm.addPass(mlir::createCanonicalizerPass()); | ||||
|   pm.addPass(mlir::createShapeInferencePass()); | ||||
|   pm.addPass(mlir::createAttributePromotionPass()); | ||||
| 
 | ||||
|   if (emissionTarget >= EmitMLIR) { | ||||
|     pm.addPass(mlir::createLowerToKrnlPass()); | ||||
|  |  | |||
|  | @ -20,6 +20,9 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass(); | |||
| 
 | ||||
| std::unique_ptr<Pass> createShapeInferencePass(); | ||||
| 
 | ||||
| /// Pass for promoting constant operands to attributes.
 | ||||
| std::unique_ptr<Pass> createAttributePromotionPass(); | ||||
| 
 | ||||
| /// Add pass for lowering to Krnl IR.
 | ||||
| std::unique_ptr<Pass> createLowerToKrnlPass(); | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,9 +12,9 @@ | |||
| #include "mlir/Pass/Pass.h" | ||||
| #include "llvm/ADT/SmallPtrSet.h" | ||||
| #include "llvm/Support/raw_ostream.h" | ||||
| #include "mlir/IR/StandardTypes.h" | ||||
| 
 | ||||
| #include "shape_inference_interface.hpp" | ||||
| #include "src/dialect/onnx/onnx_ops.hpp" | ||||
| 
 | ||||
| #include "passes.hpp" | ||||
| 
 | ||||
|  |  | |||
|  | @ -4,6 +4,19 @@ add_dependencies(onnx-mlir-opt gen_krnl_ops) | |||
| target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT}) | ||||
| target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT}) | ||||
| 
 | ||||
| target_link_libraries(onnx-mlir-opt builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_shape_inference onnx_mlir_lower_frontend curses) | ||||
| whole_archive_link_mlir(onnx-mlir-opt ${MLIRWholeArchiveLibs}) | ||||
| whole_archive_link_onnx_mlir(onnx-mlir-opt compiler onnx_mlir_transform onnx_mlir_lower_frontend onnx_mlir_shape_inference) | ||||
| target_link_libraries(onnx-mlir-opt | ||||
|         builder | ||||
|         ${MLIRLibs} | ||||
|         onnx_mlir_transform | ||||
|         onnx_mlir_shape_inference | ||||
|         onnx_mlir_lower_frontend | ||||
|         onnx_mlir_promotable_const_operands_op_interface | ||||
|         curses) | ||||
| whole_archive_link_mlir(onnx-mlir-opt | ||||
|         ${MLIRWholeArchiveLibs}) | ||||
| whole_archive_link_onnx_mlir(onnx-mlir-opt | ||||
|         compiler | ||||
|         onnx_mlir_transform | ||||
|         onnx_mlir_lower_frontend | ||||
|         onnx_mlir_shape_inference | ||||
|         onnx_mlir_attribute_promotion) | ||||
|  | @ -7,3 +7,5 @@ target_include_directories(onnx_mlir_transform | |||
|                                    ${ONNX_MLIR_SRC_ROOT}) | ||||
| target_link_libraries(onnx_mlir_transform ${MLIRLibs}) | ||||
| add_dependencies(onnx_mlir_transform gen_krnl_ops) | ||||
| 
 | ||||
| add_subdirectory(onnx) | ||||
|  | @ -0,0 +1,7 @@ | |||
| add_library(onnx_mlir_attribute_promotion | ||||
|         attribute_promotion.cpp) | ||||
| target_include_directories(onnx_mlir_attribute_promotion | ||||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||
|         ${ONNF_MLIR_SRC_ROOT}) | ||||
| target_link_libraries(onnx_mlir_attribute_promotion | ||||
|         onnx_mlir_promotable_const_operands_op_interface) | ||||
|  | @ -0,0 +1,92 @@ | |||
| //===----- attribute_promotion.cpp - Attribute Promotion
 | ||||
| //-------------------===//
 | ||||
| //
 | ||||
| // Copyright 2020 The IBM Research Authors.
 | ||||
| //
 | ||||
| // =============================================================================
 | ||||
| //
 | ||||
| // This file implements a function level pass to move an operand to become
 | ||||
| // an attribute if desirable and legal.
 | ||||
| //
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #include "mlir/Dialect/StandardOps/Ops.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/PatternMatch.h" | ||||
| #include "mlir/Pass/Pass.h" | ||||
| 
 | ||||
| #include "src/interface/promotable_const_operands_op_interface.hpp" | ||||
| #include "src/pass/passes.hpp" | ||||
| 
 | ||||
| using namespace mlir; | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| /*!
 | ||||
|  * Helper function to create a NoneTyped constant value if `none` is empty. | ||||
|  */ | ||||
| void getOrCreateNoneValue(llvm::Optional<mlir::Value> &none, FuncOp f) { | ||||
|   if (none.hasValue()) | ||||
|     return; | ||||
| 
 | ||||
|   OpBuilder builder(f.getContext()); | ||||
|   builder.setInsertionPointToStart(&f.front()); | ||||
|   none = builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr()); | ||||
| } | ||||
| 
 | ||||
| /*!
 | ||||
|  *  FunctionPass that performs attribute promotion by iterating over a list of | ||||
|  *  candidate operations and moves constant operands to attributes whenever | ||||
|  *  desirable (as instructed by the PromotableConstOperandsOpInterface). | ||||
|  */ | ||||
| class AttributePromotionPass | ||||
|     : public mlir::FunctionPass<AttributePromotionPass> { | ||||
| public: | ||||
|   void runOnFunction() override { | ||||
|     auto f = getFunction(); | ||||
| 
 | ||||
|     // A function-scope shared none value used to indicate an missing operand.
 | ||||
|     llvm::Optional<mlir::Value> none; | ||||
| 
 | ||||
|     // Iterate on the operations that may need attribute promotion.
 | ||||
|     f.walk([&](mlir::Operation *op) { | ||||
|       if (PromotableConstOperandsOpInterface opWithConstOperand = | ||||
|               dyn_cast<PromotableConstOperandsOpInterface>(op)) { | ||||
|         auto promotableOperands = opWithConstOperand.promotableConstOperands(); | ||||
|         for (const auto &operandNameToIdx : promotableOperands) { | ||||
|           auto name = operandNameToIdx.first; | ||||
|           auto i = operandNameToIdx.second; | ||||
| 
 | ||||
|           // If the i-th operand is defined by an constant operation, then
 | ||||
|           // move it to an attribute, and use None to indicate the absence
 | ||||
|           // of the original operand value.
 | ||||
|           auto operandToPromote = op->getOperand(i); | ||||
|           if (auto constantOp = dyn_cast_or_null<ConstantOp>( | ||||
|                   operandToPromote.getDefiningOp())) { | ||||
|             op->setAttr(name, constantOp.value()); | ||||
|             getOrCreateNoneValue(none, f); | ||||
|             op->setOperand(i, *none); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     }); | ||||
| 
 | ||||
|     // Dispatch canonicalization pattern rewriters to eliminate redundant
 | ||||
|     // constant operaions.
 | ||||
|     OwningRewritePatternList patterns; | ||||
|     auto *context = &getContext(); | ||||
|     ConstantOp::getCanonicalizationPatterns(patterns, context); | ||||
|     applyPatternsGreedily(f, patterns); | ||||
|   } | ||||
| }; | ||||
| } // end anonymous namespace
 | ||||
| 
 | ||||
| /*!
 | ||||
|  * Create a Attribute Promotion pass. | ||||
|  */ | ||||
| std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() { | ||||
|   return std::make_unique<AttributePromotionPass>(); | ||||
| } | ||||
| 
 | ||||
| static PassRegistration<AttributePromotionPass> pass( | ||||
|     "attribute-promotion", "Promote constant operands to attributes."); | ||||
|  | @ -0,0 +1,32 @@ | |||
| // RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s | ||||
| 
 | ||||
| func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { | ||||
|   %shape = constant dense<[6, 7, 42]> : tensor<3xi32> | ||||
|   %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
|   // CHECK-LABEL: test_should_promote_to_attribute | ||||
|   // CHECK-NEXT: [[NONE:%.+]] = constant unit | ||||
|   // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32> | ||||
|   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
|   // CHECK-LABEL: test_should_not_promote_to_attribute | ||||
|   // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, %{{.*}}) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32> | ||||
|   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) { | ||||
|   %shape = constant dense<[6, 7, 42]> : tensor<3xi32> | ||||
|   %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32> | ||||
|   %1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32> | ||||
|   "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> () | ||||
|   // CHECK-LABEL: test_promote_to_attribute_without_removing_const_op | ||||
|   // CHECK-NEXT: [[NONE:%.+]] = constant unit | ||||
|   // CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32> | ||||
|   // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32> | ||||
|   // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32> | ||||
|   // CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32> | ||||
| } | ||||
		Loading…
	
		Reference in New Issue