fix type inference (#144)
* fix type inference for ConstantOp and MaxPoolSingleOut * modify interface * use OpInterface * change name to OpInterface * Builder dependence * Update CMakeLists.txt * Update CMakeLists.txt Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
		
							parent
							
								
									6099efd91b
								
							
						
					
					
						commit
						5a542e78a0
					
				|  | @ -27,7 +27,7 @@ target_include_directories(OMBuilder | ||||||
| # variable definitions when building onnx such as -DONNX_ML=1 -DONNX_NAMESPACE=onnx | # variable definitions when building onnx such as -DONNX_ML=1 -DONNX_NAMESPACE=onnx | ||||||
| # will NOT be carried over when compiling FrontendDialectHelper.cpp, etc. so | # will NOT be carried over when compiling FrontendDialectHelper.cpp, etc. so | ||||||
| # the compilation will fail. | # the compilation will fail. | ||||||
| add_dependencies(OMBuilder OMONNXOps) | add_dependencies(OMBuilder OMONNXOps OMResultTypeInferenceOpInterface) | ||||||
| 
 | 
 | ||||||
| if (INCLUDE_ONNX_ML) | if (INCLUDE_ONNX_ML) | ||||||
|   add_dependencies(OMBuilder OMMLONNXOps) |   add_dependencies(OMBuilder OMMLONNXOps) | ||||||
|  |  | ||||||
|  | @ -20,6 +20,8 @@ | ||||||
| #include <mpark/variant.hpp> | #include <mpark/variant.hpp> | ||||||
| namespace bstd = mpark; | namespace bstd = mpark; | ||||||
| 
 | 
 | ||||||
|  | #include "src/Interface/ResultTypeInferenceOpInterface.hpp" | ||||||
|  | 
 | ||||||
| #include "FrontendDialectTransformer.hpp" | #include "FrontendDialectTransformer.hpp" | ||||||
| 
 | 
 | ||||||
| namespace onnx_mlir { | namespace onnx_mlir { | ||||||
|  | @ -296,6 +298,17 @@ private: | ||||||
| 
 | 
 | ||||||
|     // TODO: Handle optional inputs.
 |     // TODO: Handle optional inputs.
 | ||||||
|     auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); |     auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); | ||||||
|  | 
 | ||||||
|  |     // Type inference for results.
 | ||||||
|  |     if (auto opWithTypeInference = | ||||||
|  |             mlir::dyn_cast<mlir::ResultTypeInferenceOpInterface>( | ||||||
|  |                 op.getOperation())) { | ||||||
|  |       auto outTypes = opWithTypeInference.resultTypeInference(); | ||||||
|  |       for (int i = 0; i < node.output().size(); i++) { | ||||||
|  |         (*op.getODSResults(i).begin()).setType(outTypes[i]); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     for (int i = 0; i < node.output().size(); i++) { |     for (int i = 0; i < node.output().size(); i++) { | ||||||
|       frontend_symbols_.AddMapping( |       frontend_symbols_.AddMapping( | ||||||
|           legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); |           legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); | ||||||
|  |  | ||||||
|  | @ -46,6 +46,7 @@ target_link_libraries(onnx-mlir | ||||||
|         OMShapeInferenceOpInterface |         OMShapeInferenceOpInterface | ||||||
|         OMAttributePromotion |         OMAttributePromotion | ||||||
|         OMPromotableConstOperandsOpInterface |         OMPromotableConstOperandsOpInterface | ||||||
|  |         OMResultTypeInferenceOpInterface | ||||||
|         OMElideConstants |         OMElideConstants | ||||||
|         OMElideKrnlGlobalConstants |         OMElideKrnlGlobalConstants | ||||||
|         OMKrnlToAffine |         OMKrnlToAffine | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ add_dependencies(OMMLONNXOps OMMLONNXOpsIncGen) | ||||||
| # Linking dependencies: | # Linking dependencies: | ||||||
| add_dependencies(OMMLONNXOps | add_dependencies(OMMLONNXOps | ||||||
|         OMPromotableConstOperandsOpInterface |         OMPromotableConstOperandsOpInterface | ||||||
|  |         OMResultTypeInferenceOpInterface | ||||||
|         OMShapeInferenceOpInterface) |         OMShapeInferenceOpInterface) | ||||||
| 
 | 
 | ||||||
| add_onnx_mlir_dialect_doc(mlonnx MLONNXOps.td) | add_onnx_mlir_dialect_doc(mlonnx MLONNXOps.td) | ||||||
|  |  | ||||||
|  | @ -20,6 +20,7 @@ | ||||||
| #include "mlir/IR/StandardTypes.h" | #include "mlir/IR/StandardTypes.h" | ||||||
| 
 | 
 | ||||||
| #include "src/Interface/PromotableConstOperandsOpInterface.hpp" | #include "src/Interface/PromotableConstOperandsOpInterface.hpp" | ||||||
|  | #include "src/Interface/ResultTypeInferenceOpInterface.hpp" | ||||||
| #include "src/Interface/ShapeInferenceInterface.hpp" | #include "src/Interface/ShapeInferenceInterface.hpp" | ||||||
| 
 | 
 | ||||||
| namespace mlir { | namespace mlir { | ||||||
|  |  | ||||||
|  | @ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td" | ||||||
| include "src/Interface/PromotableConstOperandsOpInterface.td" | include "src/Interface/PromotableConstOperandsOpInterface.td" | ||||||
| #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||||
| 
 | 
 | ||||||
|  | #ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||||
|  | #else | ||||||
|  | include "src/Interface/ResultTypeInferenceOpInterface.td" | ||||||
|  | #endif // RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||||
|  | 
 | ||||||
| def MLONNX_Dialect : Dialect { | def MLONNX_Dialect : Dialect { | ||||||
|   let name = "mlonnx"; |   let name = "mlonnx"; | ||||||
|   let cppNamespace = ""; |   let cppNamespace = ""; | ||||||
|  |  | ||||||
|  | @ -18,6 +18,7 @@ add_dependencies(OMONNXOps OMONNXOpsIncGen) | ||||||
| # Linking dependencies: | # Linking dependencies: | ||||||
| add_dependencies(OMONNXOps | add_dependencies(OMONNXOps | ||||||
|         OMPromotableConstOperandsOpInterface |         OMPromotableConstOperandsOpInterface | ||||||
|  |         OMResultTypeInferenceOpInterface | ||||||
|         OMShapeInferenceOpInterface) |         OMShapeInferenceOpInterface) | ||||||
| 
 | 
 | ||||||
| add_onnx_mlir_dialect_doc(onnx ONNXOps.td) | add_onnx_mlir_dialect_doc(onnx ONNXOps.td) | ||||||
|  |  | ||||||
|  | @ -20,6 +20,7 @@ | ||||||
| #include "mlir/IR/StandardTypes.h" | #include "mlir/IR/StandardTypes.h" | ||||||
| 
 | 
 | ||||||
| #include "src/Interface/PromotableConstOperandsOpInterface.hpp" | #include "src/Interface/PromotableConstOperandsOpInterface.hpp" | ||||||
|  | #include "src/Interface/ResultTypeInferenceOpInterface.hpp" | ||||||
| #include "src/Interface/ShapeInferenceInterface.hpp" | #include "src/Interface/ShapeInferenceInterface.hpp" | ||||||
| 
 | 
 | ||||||
| namespace mlir { | namespace mlir { | ||||||
|  |  | ||||||
|  | @ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td" | ||||||
| include "src/Interface/PromotableConstOperandsOpInterface.td" | include "src/Interface/PromotableConstOperandsOpInterface.td" | ||||||
| #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||||
| 
 | 
 | ||||||
|  | #ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||||
|  | #else | ||||||
|  | include "src/Interface/ResultTypeInferenceOpInterface.td" | ||||||
|  | #endif // RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||||
|  | 
 | ||||||
| def ONNX_Dialect : Dialect { | def ONNX_Dialect : Dialect { | ||||||
|   let name = "onnx"; |   let name = "onnx"; | ||||||
|   let cppNamespace = ""; |   let cppNamespace = ""; | ||||||
|  | @ -120,7 +125,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | ||||||
|       return 1; |       return 1; | ||||||
|     } |     } | ||||||
|     static std::vector<int> getTypeMap() { |     static std::vector<int> getTypeMap() { | ||||||
|       return {0}; |       return {20}; | ||||||
|     } |     } | ||||||
|   }]; |   }]; | ||||||
| } | } | ||||||
|  | @ -156,7 +161,7 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", | ||||||
|       return 1; |       return 1; | ||||||
|     } |     } | ||||||
|     static std::vector<int> getTypeMap() { |     static std::vector<int> getTypeMap() { | ||||||
|       return {0}; |       return {20}; | ||||||
|     } |     } | ||||||
|   }]; |   }]; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -559,7 +559,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXConstantOp:ONNX_Op<"Constant", | def ONNXConstantOp:ONNX_Op<"Constant", | ||||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { |   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> { | ||||||
|   let summary = "ONNX Constant operation"; |   let summary = "ONNX Constant operation"; | ||||||
|   let description = [{ |   let description = [{ | ||||||
|   "A constant tensor. Exactly one of the two attributes, either value or sparse_value," |   "A constant tensor. Exactly one of the two attributes, either value or sparse_value," | ||||||
|  | @ -578,6 +578,15 @@ def ONNXConstantOp:ONNX_Op<"Constant", | ||||||
|     static std::vector<int> getTypeMap() { |     static std::vector<int> getTypeMap() { | ||||||
|       return {-1}; |       return {-1}; | ||||||
|     } |     } | ||||||
|  |     std::vector<mlir::Type> resultTypeInference() { | ||||||
|  |       std::vector<mlir::Type> resultTypes; | ||||||
|  |       if (auto attr = valueAttr()) { | ||||||
|  |         resultTypes.push_back(attr.getType()); | ||||||
|  |       } else if (auto attr = sparse_valueAttr()) { | ||||||
|  |         resultTypes.push_back(attr.getType()); | ||||||
|  |       } | ||||||
|  |       return resultTypes; | ||||||
|  |     } | ||||||
|   }]; |   }]; | ||||||
|     let builders = [ |     let builders = [ | ||||||
|     OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{ |     OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{ | ||||||
|  | @ -589,7 +598,8 @@ def ONNXConstantOp:ONNX_Op<"Constant", | ||||||
|         build(builder, state, tensorType, sparse_value, value); |         build(builder, state, tensorType, sparse_value, value); | ||||||
|       } |       } | ||||||
|     }]> |     }]> | ||||||
|     ];} |     ]; | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", | def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", | ||||||
|   [NoSideEffect]> { |   [NoSideEffect]> { | ||||||
|  |  | ||||||
|  | @ -25,3 +25,17 @@ target_include_directories(OMShapeInferenceOpInterface | ||||||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} |         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||||
|         ${ONNX_MLIR_SRC_ROOT}) |         ${ONNX_MLIR_SRC_ROOT}) | ||||||
| add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen) | add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen) | ||||||
|  | 
 | ||||||
|  | set(LLVM_TARGET_DEFINITIONS ResultTypeInferenceOpInterface.td) | ||||||
|  | onnx_mlir_tablegen(ResultTypeInferenceOpInterface.hpp.inc -gen-op-interface-decls) | ||||||
|  | onnx_mlir_tablegen(ResultTypeInferenceOpInterface.cpp.inc -gen-op-interface-defs) | ||||||
|  | add_public_tablegen_target(OMResultTypeInferenceOpInterfaceIncGen) | ||||||
|  | 
 | ||||||
|  | add_library(OMResultTypeInferenceOpInterface | ||||||
|  |         ResultTypeInferenceOpInterface.hpp | ||||||
|  |         ResultTypeInferenceOpInterface.cpp) | ||||||
|  | target_include_directories(OMResultTypeInferenceOpInterface | ||||||
|  |         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||||
|  |         ${ONNX_MLIR_SRC_ROOT}) | ||||||
|  | add_dependencies(OMResultTypeInferenceOpInterface | ||||||
|  |         OMResultTypeInferenceOpInterfaceIncGen) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,20 @@ | ||||||
|  | //===------------ ResultTypeInferenceOpInterface.cpp --------------===//
 | ||||||
|  | //===--------- Infer Data Type for Results Interface Definition --------===//
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2020 The IBM Research Authors.
 | ||||||
|  | //
 | ||||||
|  | // =============================================================================
 | ||||||
|  | //
 | ||||||
|  | // This file contains the implementation of the data type inference interface.
 | ||||||
|  | //
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | #include "ResultTypeInferenceOpInterface.hpp" | ||||||
|  | 
 | ||||||
|  | using namespace mlir; | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Infer Data Type for Results Interface
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | #include "src/Interface/ResultTypeInferenceOpInterface.cpp.inc" | ||||||
|  | @ -0,0 +1,25 @@ | ||||||
|  | //===------------ ResultTypeInferenceOpInterface.hpp --------------===//
 | ||||||
|  | //===------- Infer Data Type for Result of Op Interface Definition -------===//
 | ||||||
|  | //
 | ||||||
|  | // Copyright 2020 The IBM Research Authors.
 | ||||||
|  | //
 | ||||||
|  | // =============================================================================
 | ||||||
|  | //
 | ||||||
|  | // This file contains the declaration of the data type reference for op
 | ||||||
|  | // interface.
 | ||||||
|  | //
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | #pragma once | ||||||
|  | 
 | ||||||
|  | #include <map> | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
|  | #include "mlir/IR/OpDefinition.h" | ||||||
|  | 
 | ||||||
|  | namespace mlir { | ||||||
|  | 
 | ||||||
|  | /// Include the auto-generated declarations.
 | ||||||
|  | #include "src/Interface/ResultTypeInferenceOpInterface.hpp.inc" | ||||||
|  | 
 | ||||||
|  | } // end namespace mlir
 | ||||||
|  | @ -0,0 +1,34 @@ | ||||||
|  | //===------------ ResultTypeInferenceOpInterface.td --------------===// | ||||||
|  | //===--------- Infer Data Type for Results Interface Definition --------===// | ||||||
|  | // | ||||||
|  | // Copyright 2020 The IBM Research Authors. | ||||||
|  | // | ||||||
|  | // ============================================================================= | ||||||
|  | // | ||||||
|  | // This file contains the tablegen of the data type inference interface. | ||||||
|  | // | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | 
 | ||||||
|  | #ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||||
|  | #else | ||||||
|  | #define RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||||
|  | 
 | ||||||
|  | #ifdef OP_BASE | ||||||
|  | #else | ||||||
|  | include "mlir/IR/OpBase.td" | ||||||
|  | #endif // OP_BASE | ||||||
|  | 
 | ||||||
|  | def ResultTypeInferenceOpInterface : OpInterface<"ResultTypeInferenceOpInterface"> { | ||||||
|  |   let description = [{ | ||||||
|  |     Interface to access a registered method to infer the data types for  | ||||||
|  |     the result of an operation | ||||||
|  |   }]; | ||||||
|  | 
 | ||||||
|  |   let methods = [ | ||||||
|  |       InterfaceMethod<"Infer output data type for this operation class.", | ||||||
|  |           "std::vector<mlir::Type>", "resultTypeInference" | ||||||
|  |       > | ||||||
|  |   ]; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #endif // RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||||
|  | @ -271,6 +271,17 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv'] | ||||||
| OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)], | OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)], | ||||||
|                                   "Pad": [("pads", 1), ("constant_value", 2)]} |                                   "Pad": [("pads", 1), ("constant_value", 2)]} | ||||||
| 
 | 
 | ||||||
|  | # Interface for special handling of type inference | ||||||
|  | # The common code are put into get_type_inference_func | ||||||
|  | OpsWithResultTypeInference = { | ||||||
|  |   "Constant": | ||||||
|  |   '''if (auto attr = valueAttr()) { | ||||||
|  |         resultTypes.push_back(attr.getType()); | ||||||
|  |       } else if (auto attr = sparse_valueAttr()) { | ||||||
|  |         resultTypes.push_back(attr.getType()); | ||||||
|  |       }''' | ||||||
|  | } | ||||||
|  |         | ||||||
| # Add an Op in this list if the Op needs result type deduction which is required | # Add an Op in this list if the Op needs result type deduction which is required | ||||||
| # when writing declarative rewriting rules. Deduced type is always | # when writing declarative rewriting rules. Deduced type is always | ||||||
| # an UnrankedTensorType whose element type is the same as the first operand's | # an UnrankedTensorType whose element type is the same as the first operand's | ||||||
|  | @ -634,6 +645,23 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): | ||||||
| 
 | 
 | ||||||
|     return s |     return s | ||||||
| 
 | 
 | ||||||
|  | def get_type_inference_func(s, indent, type_inference_code): | ||||||
|  |     indent = inc_indent(indent) | ||||||
|  | 
 | ||||||
|  |     s += indent + "std::vector<mlir::Type> resultTypeInference() {" + "\n" | ||||||
|  |     indent = inc_indent(indent) | ||||||
|  |     s += indent + "std::vector<mlir::Type> resultTypes;" + "\n" | ||||||
|  | 
 | ||||||
|  |     s += indent + type_inference_code + '\n' | ||||||
|  | 
 | ||||||
|  |     s += indent + "return resultTypes;" + "\n" | ||||||
|  |     indent = dec_indent(indent) | ||||||
|  |     s += indent + "}" + "\n" | ||||||
|  | 
 | ||||||
|  |     indent = dec_indent(indent) | ||||||
|  |     return s | ||||||
|  |    | ||||||
|  |    | ||||||
| 
 | 
 | ||||||
| def gen_op_def(schema): | def gen_op_def(schema): | ||||||
|     indent = inc_indent() |     indent = inc_indent() | ||||||
|  | @ -648,6 +676,8 @@ def gen_op_def(schema): | ||||||
|         traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>") |         traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>") | ||||||
|     if schema.name in OpsWithPromotableConstOperands.keys(): |     if schema.name in OpsWithPromotableConstOperands.keys(): | ||||||
|         traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">") |         traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">") | ||||||
|  |     if schema.name in OpsWithResultTypeInference.keys(): | ||||||
|  |         traits.append("OpInterface<\"ResultTypeInferenceOpInterface\">") | ||||||
|     s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits)) |     s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits)) | ||||||
| 
 | 
 | ||||||
|     # Generate decl for canonicalizer. |     # Generate decl for canonicalizer. | ||||||
|  | @ -739,10 +769,14 @@ def gen_op_def(schema): | ||||||
|         s = get_promotable_const_operands_func( |         s = get_promotable_const_operands_func( | ||||||
|             s, indent, OpsWithPromotableConstOperands[schema.name]) |             s, indent, OpsWithPromotableConstOperands[schema.name]) | ||||||
| 
 | 
 | ||||||
|  |     if schema.name in OpsWithResultTypeInference: | ||||||
|  |         s = get_type_inference_func( | ||||||
|  |             s, indent, OpsWithResultTypeInference[schema.name]) | ||||||
|  | 
 | ||||||
|     s += indent + '}];\n' |     s += indent + '}];\n' | ||||||
| 
 | 
 | ||||||
|     if ( schema.name in custom_definition_misc) : |     if ( schema.name in custom_definition_misc) : | ||||||
|         s += custom_definition_misc[schema.name] |         s += custom_definition_misc[schema.name] + '\n' | ||||||
| 
 | 
 | ||||||
|     s += '}\n\n' |     s += '}\n\n' | ||||||
|     return s |     return s | ||||||
|  | @ -852,6 +886,7 @@ def build_operator_schemas(): | ||||||
|                         print("Your onnx may be too old." |                         print("Your onnx may be too old." | ||||||
|                            "right version for opertion {} not found".format( |                            "right version for opertion {} not found".format( | ||||||
|                             schema.name)) |                             schema.name)) | ||||||
|  |                         sys.exit() | ||||||
|             processed_supportmap.append((_support, processed_namemap)) |             processed_supportmap.append((_support, processed_namemap)) | ||||||
|         operator_schemas.append((domain, processed_supportmap)) |         operator_schemas.append((domain, processed_supportmap)) | ||||||
|     return operator_schemas |     return operator_schemas | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue