fix type inference (#144)
* fix type inference for ConstantOp and MaxPoolSingleOut * modify interface * use OpInterface * change name to OpInterface * Builder dependence * Update CMakeLists.txt * Update CMakeLists.txt Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
		
							parent
							
								
									6099efd91b
								
							
						
					
					
						commit
						5a542e78a0
					
				|  | @ -27,7 +27,7 @@ target_include_directories(OMBuilder | |||
| # variable definitions when building onnx such as -DONNX_ML=1 -DONNX_NAMESPACE=onnx | ||||
| # will NOT be carried over when compiling FrontendDialectHelper.cpp, etc. so | ||||
| # the compilation will fail. | ||||
| add_dependencies(OMBuilder OMONNXOps) | ||||
| add_dependencies(OMBuilder OMONNXOps OMResultTypeInferenceOpInterface) | ||||
| 
 | ||||
| if (INCLUDE_ONNX_ML) | ||||
|   add_dependencies(OMBuilder OMMLONNXOps) | ||||
|  |  | |||
|  | @ -20,6 +20,8 @@ | |||
| #include <mpark/variant.hpp> | ||||
| namespace bstd = mpark; | ||||
| 
 | ||||
| #include "src/Interface/ResultTypeInferenceOpInterface.hpp" | ||||
| 
 | ||||
| #include "FrontendDialectTransformer.hpp" | ||||
| 
 | ||||
| namespace onnx_mlir { | ||||
|  | @ -296,6 +298,17 @@ private: | |||
| 
 | ||||
|     // TODO: Handle optional inputs.
 | ||||
|     auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); | ||||
| 
 | ||||
|     // Type inference for results.
 | ||||
|     if (auto opWithTypeInference = | ||||
|             mlir::dyn_cast<mlir::ResultTypeInferenceOpInterface>( | ||||
|                 op.getOperation())) { | ||||
|       auto outTypes = opWithTypeInference.resultTypeInference(); | ||||
|       for (int i = 0; i < node.output().size(); i++) { | ||||
|         (*op.getODSResults(i).begin()).setType(outTypes[i]); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     for (int i = 0; i < node.output().size(); i++) { | ||||
|       frontend_symbols_.AddMapping( | ||||
|           legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); | ||||
|  |  | |||
|  | @ -46,6 +46,7 @@ target_link_libraries(onnx-mlir | |||
|         OMShapeInferenceOpInterface | ||||
|         OMAttributePromotion | ||||
|         OMPromotableConstOperandsOpInterface | ||||
|         OMResultTypeInferenceOpInterface | ||||
|         OMElideConstants | ||||
|         OMElideKrnlGlobalConstants | ||||
|         OMKrnlToAffine | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ add_dependencies(OMMLONNXOps OMMLONNXOpsIncGen) | |||
| # Linking dependencies: | ||||
| add_dependencies(OMMLONNXOps | ||||
|         OMPromotableConstOperandsOpInterface | ||||
|         OMResultTypeInferenceOpInterface | ||||
|         OMShapeInferenceOpInterface) | ||||
| 
 | ||||
| add_onnx_mlir_dialect_doc(mlonnx MLONNXOps.td) | ||||
|  |  | |||
|  | @ -20,6 +20,7 @@ | |||
| #include "mlir/IR/StandardTypes.h" | ||||
| 
 | ||||
| #include "src/Interface/PromotableConstOperandsOpInterface.hpp" | ||||
| #include "src/Interface/ResultTypeInferenceOpInterface.hpp" | ||||
| #include "src/Interface/ShapeInferenceInterface.hpp" | ||||
| 
 | ||||
| namespace mlir { | ||||
|  |  | |||
|  | @ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td" | |||
| include "src/Interface/PromotableConstOperandsOpInterface.td" | ||||
| #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||
| 
 | ||||
| #ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||
| #else | ||||
| include "src/Interface/ResultTypeInferenceOpInterface.td" | ||||
| #endif // RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||
| 
 | ||||
| def MLONNX_Dialect : Dialect { | ||||
|   let name = "mlonnx"; | ||||
|   let cppNamespace = ""; | ||||
|  |  | |||
|  | @ -18,6 +18,7 @@ add_dependencies(OMONNXOps OMONNXOpsIncGen) | |||
| # Linking dependencies: | ||||
| add_dependencies(OMONNXOps | ||||
|         OMPromotableConstOperandsOpInterface | ||||
|         OMResultTypeInferenceOpInterface | ||||
|         OMShapeInferenceOpInterface) | ||||
| 
 | ||||
| add_onnx_mlir_dialect_doc(onnx ONNXOps.td) | ||||
|  |  | |||
|  | @ -20,6 +20,7 @@ | |||
| #include "mlir/IR/StandardTypes.h" | ||||
| 
 | ||||
| #include "src/Interface/PromotableConstOperandsOpInterface.hpp" | ||||
| #include "src/Interface/ResultTypeInferenceOpInterface.hpp" | ||||
| #include "src/Interface/ShapeInferenceInterface.hpp" | ||||
| 
 | ||||
| namespace mlir { | ||||
|  |  | |||
|  | @ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td" | |||
| include "src/Interface/PromotableConstOperandsOpInterface.td" | ||||
| #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE | ||||
| 
 | ||||
| #ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||
| #else | ||||
| include "src/Interface/ResultTypeInferenceOpInterface.td" | ||||
| #endif // RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||
| 
 | ||||
| def ONNX_Dialect : Dialect { | ||||
|   let name = "onnx"; | ||||
|   let cppNamespace = ""; | ||||
|  | @ -120,7 +125,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | |||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {0}; | ||||
|       return {20}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
|  | @ -156,7 +161,7 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", | |||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {0}; | ||||
|       return {20}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
|  |  | |||
|  | @ -559,7 +559,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", | |||
| } | ||||
| 
 | ||||
| def ONNXConstantOp:ONNX_Op<"Constant", | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> { | ||||
|   let summary = "ONNX Constant operation"; | ||||
|   let description = [{ | ||||
|   "A constant tensor. Exactly one of the two attributes, either value or sparse_value," | ||||
|  | @ -578,6 +578,15 @@ def ONNXConstantOp:ONNX_Op<"Constant", | |||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|     std::vector<mlir::Type> resultTypeInference() { | ||||
|       std::vector<mlir::Type> resultTypes; | ||||
|       if (auto attr = valueAttr()) { | ||||
|         resultTypes.push_back(attr.getType()); | ||||
|       } else if (auto attr = sparse_valueAttr()) { | ||||
|         resultTypes.push_back(attr.getType()); | ||||
|       } | ||||
|       return resultTypes; | ||||
|     } | ||||
|   }]; | ||||
|     let builders = [ | ||||
|     OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{ | ||||
|  | @ -589,7 +598,8 @@ def ONNXConstantOp:ONNX_Op<"Constant", | |||
|         build(builder, state, tensorType, sparse_value, value); | ||||
|       } | ||||
|     }]> | ||||
|     ];} | ||||
|     ]; | ||||
| } | ||||
| 
 | ||||
| def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", | ||||
|   [NoSideEffect]> { | ||||
|  |  | |||
|  | @ -25,3 +25,17 @@ target_include_directories(OMShapeInferenceOpInterface | |||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||
|         ${ONNX_MLIR_SRC_ROOT}) | ||||
| add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen) | ||||
| 
 | ||||
| set(LLVM_TARGET_DEFINITIONS ResultTypeInferenceOpInterface.td) | ||||
| onnx_mlir_tablegen(ResultTypeInferenceOpInterface.hpp.inc -gen-op-interface-decls) | ||||
| onnx_mlir_tablegen(ResultTypeInferenceOpInterface.cpp.inc -gen-op-interface-defs) | ||||
| add_public_tablegen_target(OMResultTypeInferenceOpInterfaceIncGen) | ||||
| 
 | ||||
| add_library(OMResultTypeInferenceOpInterface | ||||
|         ResultTypeInferenceOpInterface.hpp | ||||
|         ResultTypeInferenceOpInterface.cpp) | ||||
| target_include_directories(OMResultTypeInferenceOpInterface | ||||
|         PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} | ||||
|         ${ONNX_MLIR_SRC_ROOT}) | ||||
| add_dependencies(OMResultTypeInferenceOpInterface | ||||
|         OMResultTypeInferenceOpInterfaceIncGen) | ||||
|  |  | |||
|  | @ -0,0 +1,20 @@ | |||
| //===------------ ResultTypeInferenceOpInterface.cpp --------------===//
 | ||||
| //===--------- Infer Data Type for Results Interface Definition --------===//
 | ||||
| //
 | ||||
| // Copyright 2020 The IBM Research Authors.
 | ||||
| //
 | ||||
| // =============================================================================
 | ||||
| //
 | ||||
| // This file contains the implementation of the data type inference interface.
 | ||||
| //
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #include "ResultTypeInferenceOpInterface.hpp" | ||||
| 
 | ||||
| using namespace mlir; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Infer Data Type for Results Interface
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #include "src/Interface/ResultTypeInferenceOpInterface.cpp.inc" | ||||
|  | @ -0,0 +1,25 @@ | |||
| //===------------ ResultTypeInferenceOpInterface.hpp --------------===//
 | ||||
| //===------- Infer Data Type for Result of Op Interface Definition -------===//
 | ||||
| //
 | ||||
| // Copyright 2020 The IBM Research Authors.
 | ||||
| //
 | ||||
| // =============================================================================
 | ||||
| //
 | ||||
| // This file contains the declaration of the data type reference for op
 | ||||
| // interface.
 | ||||
| //
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <map> | ||||
| #include <string> | ||||
| 
 | ||||
| #include "mlir/IR/OpDefinition.h" | ||||
| 
 | ||||
| namespace mlir { | ||||
| 
 | ||||
| /// Include the auto-generated declarations.
 | ||||
| #include "src/Interface/ResultTypeInferenceOpInterface.hpp.inc" | ||||
| 
 | ||||
| } // end namespace mlir
 | ||||
|  | @ -0,0 +1,34 @@ | |||
| //===------------ ResultTypeInferenceOpInterface.td --------------===// | ||||
| //===--------- Infer Data Type for Results Interface Definition --------===// | ||||
| // | ||||
| // Copyright 2020 The IBM Research Authors. | ||||
| // | ||||
| // ============================================================================= | ||||
| // | ||||
| // This file contains the tablegen of the data type inference interface. | ||||
| // | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| #ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||
| #else | ||||
| #define RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||
| 
 | ||||
| #ifdef OP_BASE | ||||
| #else | ||||
| include "mlir/IR/OpBase.td" | ||||
| #endif // OP_BASE | ||||
| 
 | ||||
| def ResultTypeInferenceOpInterface : OpInterface<"ResultTypeInferenceOpInterface"> { | ||||
|   let description = [{ | ||||
|     Interface to access a registered method to infer the data types for  | ||||
|     the result of an operation | ||||
|   }]; | ||||
| 
 | ||||
|   let methods = [ | ||||
|       InterfaceMethod<"Infer output data type for this operation class.", | ||||
|           "std::vector<mlir::Type>", "resultTypeInference" | ||||
|       > | ||||
|   ]; | ||||
| } | ||||
| 
 | ||||
| #endif // RESULT_TYPE_INFERENCE_OP_INTERFACE | ||||
|  | @ -271,6 +271,17 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv'] | |||
| OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)], | ||||
|                                   "Pad": [("pads", 1), ("constant_value", 2)]} | ||||
| 
 | ||||
| # Interface for special handling of type inference | ||||
| # The common code are put into get_type_inference_func | ||||
| OpsWithResultTypeInference = { | ||||
|   "Constant": | ||||
|   '''if (auto attr = valueAttr()) { | ||||
|         resultTypes.push_back(attr.getType()); | ||||
|       } else if (auto attr = sparse_valueAttr()) { | ||||
|         resultTypes.push_back(attr.getType()); | ||||
|       }''' | ||||
| } | ||||
|         | ||||
| # Add an Op in this list if the Op needs result type deduction which is required | ||||
| # when writing declarative rewriting rules. Deduced type is always | ||||
| # an UnrankedTensorType whose element type is the same as the first operand's | ||||
|  | @ -634,6 +645,23 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): | |||
| 
 | ||||
|     return s | ||||
| 
 | ||||
| def get_type_inference_func(s, indent, type_inference_code): | ||||
|     indent = inc_indent(indent) | ||||
| 
 | ||||
|     s += indent + "std::vector<mlir::Type> resultTypeInference() {" + "\n" | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "std::vector<mlir::Type> resultTypes;" + "\n" | ||||
| 
 | ||||
|     s += indent + type_inference_code + '\n' | ||||
| 
 | ||||
|     s += indent + "return resultTypes;" + "\n" | ||||
|     indent = dec_indent(indent) | ||||
|     s += indent + "}" + "\n" | ||||
| 
 | ||||
|     indent = dec_indent(indent) | ||||
|     return s | ||||
|    | ||||
|    | ||||
| 
 | ||||
| def gen_op_def(schema): | ||||
|     indent = inc_indent() | ||||
|  | @ -648,6 +676,8 @@ def gen_op_def(schema): | |||
|         traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>") | ||||
|     if schema.name in OpsWithPromotableConstOperands.keys(): | ||||
|         traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">") | ||||
|     if schema.name in OpsWithResultTypeInference.keys(): | ||||
|         traits.append("OpInterface<\"ResultTypeInferenceOpInterface\">") | ||||
|     s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits)) | ||||
| 
 | ||||
|     # Generate decl for canonicalizer. | ||||
|  | @ -739,10 +769,14 @@ def gen_op_def(schema): | |||
|         s = get_promotable_const_operands_func( | ||||
|             s, indent, OpsWithPromotableConstOperands[schema.name]) | ||||
| 
 | ||||
|     if schema.name in OpsWithResultTypeInference: | ||||
|         s = get_type_inference_func( | ||||
|             s, indent, OpsWithResultTypeInference[schema.name]) | ||||
| 
 | ||||
|     s += indent + '}];\n' | ||||
| 
 | ||||
|     if ( schema.name in custom_definition_misc) : | ||||
|         s += custom_definition_misc[schema.name] | ||||
|         s += custom_definition_misc[schema.name] + '\n' | ||||
| 
 | ||||
|     s += '}\n\n' | ||||
|     return s | ||||
|  | @ -852,6 +886,7 @@ def build_operator_schemas(): | |||
|                         print("Your onnx may be too old." | ||||
|                            "right version for opertion {} not found".format( | ||||
|                             schema.name)) | ||||
|                         sys.exit() | ||||
|             processed_supportmap.append((_support, processed_namemap)) | ||||
|         operator_schemas.append((domain, processed_supportmap)) | ||||
|     return operator_schemas | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue