From 5a542e78a0e7f7f424cc2c7acba5709dff9187d1 Mon Sep 17 00:00:00 2001 From: chentong319 Date: Mon, 25 May 2020 21:54:19 -0400 Subject: [PATCH] fix type inference (#144) * fix type inference for ConstantOp and MaxPoolSingleOut * modify interface * use OpInterface * change name to OpInterface * Builder dependence * Update CMakeLists.txt * Update CMakeLists.txt Co-authored-by: Tian Jin --- src/Builder/CMakeLists.txt | 2 +- src/Builder/FrontendDialectTransformer.cpp | 13 +++++++ src/CMakeLists.txt | 3 +- src/Dialect/MLONNX/CMakeLists.txt | 1 + src/Dialect/MLONNX/MLONNXOps.hpp | 1 + src/Dialect/MLONNX/MLONNXOps.td | 5 +++ src/Dialect/ONNX/CMakeLists.txt | 1 + src/Dialect/ONNX/ONNXOps.hpp | 1 + src/Dialect/ONNX/ONNXOps.td | 9 ++++- src/Dialect/ONNX/ONNXOps.td.inc | 14 ++++++- src/Interface/CMakeLists.txt | 16 +++++++- .../ResultTypeInferenceOpInterface.cpp | 20 ++++++++++ .../ResultTypeInferenceOpInterface.hpp | 25 +++++++++++++ .../ResultTypeInferenceOpInterface.td | 34 +++++++++++++++++ utils/gen_onnx_mlir.py | 37 ++++++++++++++++++- 15 files changed, 174 insertions(+), 8 deletions(-) create mode 100644 src/Interface/ResultTypeInferenceOpInterface.cpp create mode 100644 src/Interface/ResultTypeInferenceOpInterface.hpp create mode 100644 src/Interface/ResultTypeInferenceOpInterface.td diff --git a/src/Builder/CMakeLists.txt b/src/Builder/CMakeLists.txt index 1e87f7e..bf89da7 100644 --- a/src/Builder/CMakeLists.txt +++ b/src/Builder/CMakeLists.txt @@ -27,7 +27,7 @@ target_include_directories(OMBuilder # variable definitions when building onnx such as -DONNX_ML=1 -DONNX_NAMESPACE=onnx # will NOT be carried over when compiling FrontendDialectHelper.cpp, etc. so # the compilation will fail. -add_dependencies(OMBuilder OMONNXOps) +add_dependencies(OMBuilder OMONNXOps OMResultTypeInferenceOpInterface) if (INCLUDE_ONNX_ML) add_dependencies(OMBuilder OMMLONNXOps) diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 932d59d..300b94b 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -20,6 +20,8 @@ #include namespace bstd = mpark; +#include "src/Interface/ResultTypeInferenceOpInterface.hpp" + #include "FrontendDialectTransformer.hpp" namespace onnx_mlir { @@ -296,6 +298,17 @@ private: // TODO: Handle optional inputs. auto op = builder_.create(UnknownLoc(), outputTypes, inputs, attributes); + + // Type inference for results. + if (auto opWithTypeInference = + mlir::dyn_cast( + op.getOperation())) { + auto outTypes = opWithTypeInference.resultTypeInference(); + for (int i = 0; i < node.output().size(); i++) { + (*op.getODSResults(i).begin()).setType(outTypes[i]); + } + } + for (int i = 0; i < node.output().size(); i++) { frontend_symbols_.AddMapping( legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d84b3b4..16f9b83 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -46,6 +46,7 @@ target_link_libraries(onnx-mlir OMShapeInferenceOpInterface OMAttributePromotion OMPromotableConstOperandsOpInterface + OMResultTypeInferenceOpInterface OMElideConstants OMElideKrnlGlobalConstants OMKrnlToAffine @@ -66,4 +67,4 @@ target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_BIN_ROOT}) install(TARGETS onnx-mlir DESTINATION bin) install(FILES Runtime/DynMemRef.h DESTINATION include) -install(TARGETS cruntime DESTINATION lib) \ No newline at end of file +install(TARGETS cruntime DESTINATION lib) diff --git a/src/Dialect/MLONNX/CMakeLists.txt b/src/Dialect/MLONNX/CMakeLists.txt index da4e977..6f50aea 100644 --- a/src/Dialect/MLONNX/CMakeLists.txt +++ b/src/Dialect/MLONNX/CMakeLists.txt @@ -17,6 +17,7 @@ add_dependencies(OMMLONNXOps OMMLONNXOpsIncGen) # Linking dependencies: add_dependencies(OMMLONNXOps OMPromotableConstOperandsOpInterface + OMResultTypeInferenceOpInterface OMShapeInferenceOpInterface) add_onnx_mlir_dialect_doc(mlonnx MLONNXOps.td) diff --git a/src/Dialect/MLONNX/MLONNXOps.hpp b/src/Dialect/MLONNX/MLONNXOps.hpp index 3859756..5457a0b 100644 --- a/src/Dialect/MLONNX/MLONNXOps.hpp +++ b/src/Dialect/MLONNX/MLONNXOps.hpp @@ -20,6 +20,7 @@ #include "mlir/IR/StandardTypes.h" #include "src/Interface/PromotableConstOperandsOpInterface.hpp" +#include "src/Interface/ResultTypeInferenceOpInterface.hpp" #include "src/Interface/ShapeInferenceInterface.hpp" namespace mlir { diff --git a/src/Dialect/MLONNX/MLONNXOps.td b/src/Dialect/MLONNX/MLONNXOps.td index 713881a..296442e 100644 --- a/src/Dialect/MLONNX/MLONNXOps.td +++ b/src/Dialect/MLONNX/MLONNXOps.td @@ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td" include "src/Interface/PromotableConstOperandsOpInterface.td" #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE +#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE +#else +include "src/Interface/ResultTypeInferenceOpInterface.td" +#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE + def MLONNX_Dialect : Dialect { let name = "mlonnx"; let cppNamespace = ""; diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index 20c671d..8295a58 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -18,6 +18,7 @@ add_dependencies(OMONNXOps OMONNXOpsIncGen) # Linking dependencies: add_dependencies(OMONNXOps OMPromotableConstOperandsOpInterface + OMResultTypeInferenceOpInterface OMShapeInferenceOpInterface) add_onnx_mlir_dialect_doc(onnx ONNXOps.td) diff --git a/src/Dialect/ONNX/ONNXOps.hpp b/src/Dialect/ONNX/ONNXOps.hpp index 66c0e7e..2d9e871 100644 --- a/src/Dialect/ONNX/ONNXOps.hpp +++ b/src/Dialect/ONNX/ONNXOps.hpp @@ -20,6 +20,7 @@ #include "mlir/IR/StandardTypes.h" #include "src/Interface/PromotableConstOperandsOpInterface.hpp" +#include "src/Interface/ResultTypeInferenceOpInterface.hpp" #include "src/Interface/ShapeInferenceInterface.hpp" namespace mlir { diff --git a/src/Dialect/ONNX/ONNXOps.td b/src/Dialect/ONNX/ONNXOps.td index 8187ddb..9dff9e9 100644 --- a/src/Dialect/ONNX/ONNXOps.td +++ b/src/Dialect/ONNX/ONNXOps.td @@ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td" include "src/Interface/PromotableConstOperandsOpInterface.td" #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE +#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE +#else +include "src/Interface/ResultTypeInferenceOpInterface.td" +#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE + def ONNX_Dialect : Dialect { let name = "onnx"; let cppNamespace = ""; @@ -120,7 +125,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", return 1; } static std::vector getTypeMap() { - return {0}; + return {20}; } }]; } @@ -156,7 +161,7 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", return 1; } static std::vector getTypeMap() { - return {0}; + return {20}; } }]; } diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index e49955c..e6a361e 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -559,7 +559,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", } def ONNXConstantOp:ONNX_Op<"Constant", - [NoSideEffect, DeclareOpInterfaceMethods]> { + [NoSideEffect, DeclareOpInterfaceMethods, OpInterface<"ResultTypeInferenceOpInterface">]> { let summary = "ONNX Constant operation"; let description = [{ "A constant tensor. Exactly one of the two attributes, either value or sparse_value," @@ -578,6 +578,15 @@ def ONNXConstantOp:ONNX_Op<"Constant", static std::vector getTypeMap() { return {-1}; } + std::vector resultTypeInference() { + std::vector resultTypes; + if (auto attr = valueAttr()) { + resultTypes.push_back(attr.getType()); + } else if (auto attr = sparse_valueAttr()) { + resultTypes.push_back(attr.getType()); + } + return resultTypes; + } }]; let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{ @@ -589,7 +598,8 @@ def ONNXConstantOp:ONNX_Op<"Constant", build(builder, state, tensorType, sparse_value, value); } }]> - ];} + ]; +} def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", [NoSideEffect]> { diff --git a/src/Interface/CMakeLists.txt b/src/Interface/CMakeLists.txt index e6ab929..476eeb4 100644 --- a/src/Interface/CMakeLists.txt +++ b/src/Interface/CMakeLists.txt @@ -24,4 +24,18 @@ add_library(OMShapeInferenceOpInterface target_include_directories(OMShapeInferenceOpInterface PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_SRC_ROOT}) -add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen) \ No newline at end of file +add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen) + +set(LLVM_TARGET_DEFINITIONS ResultTypeInferenceOpInterface.td) +onnx_mlir_tablegen(ResultTypeInferenceOpInterface.hpp.inc -gen-op-interface-decls) +onnx_mlir_tablegen(ResultTypeInferenceOpInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(OMResultTypeInferenceOpInterfaceIncGen) + +add_library(OMResultTypeInferenceOpInterface + ResultTypeInferenceOpInterface.hpp + ResultTypeInferenceOpInterface.cpp) +target_include_directories(OMResultTypeInferenceOpInterface + PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} + ${ONNX_MLIR_SRC_ROOT}) +add_dependencies(OMResultTypeInferenceOpInterface + OMResultTypeInferenceOpInterfaceIncGen) diff --git a/src/Interface/ResultTypeInferenceOpInterface.cpp b/src/Interface/ResultTypeInferenceOpInterface.cpp new file mode 100644 index 0000000..0c1ae4b --- /dev/null +++ b/src/Interface/ResultTypeInferenceOpInterface.cpp @@ -0,0 +1,20 @@ +//===------------ ResultTypeInferenceOpInterface.cpp --------------===// +//===--------- Infer Data Type for Results Interface Definition --------===// +// +// Copyright 2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the implementation of the data type inference interface. +// +//===----------------------------------------------------------------------===// + +#include "ResultTypeInferenceOpInterface.hpp" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Infer Data Type for Results Interface +//===----------------------------------------------------------------------===// + +#include "src/Interface/ResultTypeInferenceOpInterface.cpp.inc" diff --git a/src/Interface/ResultTypeInferenceOpInterface.hpp b/src/Interface/ResultTypeInferenceOpInterface.hpp new file mode 100644 index 0000000..6ea6426 --- /dev/null +++ b/src/Interface/ResultTypeInferenceOpInterface.hpp @@ -0,0 +1,25 @@ +//===------------ ResultTypeInferenceOpInterface.hpp --------------===// +//===------- Infer Data Type for Result of Op Interface Definition -------===// +// +// Copyright 2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the declaration of the data type reference for op +// interface. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { + +/// Include the auto-generated declarations. +#include "src/Interface/ResultTypeInferenceOpInterface.hpp.inc" + +} // end namespace mlir diff --git a/src/Interface/ResultTypeInferenceOpInterface.td b/src/Interface/ResultTypeInferenceOpInterface.td new file mode 100644 index 0000000..378e2b0 --- /dev/null +++ b/src/Interface/ResultTypeInferenceOpInterface.td @@ -0,0 +1,34 @@ +//===------------ ResultTypeInferenceOpInterface.td --------------===// +//===--------- Infer Data Type for Results Interface Definition --------===// +// +// Copyright 2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the tablegen of the data type inference interface. +// +//===----------------------------------------------------------------------===// + +#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE +#else +#define RESULT_TYPE_INFERENCE_OP_INTERFACE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def ResultTypeInferenceOpInterface : OpInterface<"ResultTypeInferenceOpInterface"> { + let description = [{ + Interface to access a registered method to infer the data types for + the result of an operation + }]; + + let methods = [ + InterfaceMethod<"Infer output data type for this operation class.", + "std::vector", "resultTypeInference" + > + ]; +} + +#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 76b1cbf..cce1734 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -271,6 +271,17 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv'] OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)], "Pad": [("pads", 1), ("constant_value", 2)]} +# Interface for special handling of type inference +# The common code are put into get_type_inference_func +OpsWithResultTypeInference = { + "Constant": + '''if (auto attr = valueAttr()) { + resultTypes.push_back(attr.getType()); + } else if (auto attr = sparse_valueAttr()) { + resultTypes.push_back(attr.getType()); + }''' +} + # Add an Op in this list if the Op needs result type deduction which is required # when writing declarative rewriting rules. Deduced type is always # an UnrankedTensorType whose element type is the same as the first operand's @@ -634,6 +645,23 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): return s +def get_type_inference_func(s, indent, type_inference_code): + indent = inc_indent(indent) + + s += indent + "std::vector resultTypeInference() {" + "\n" + indent = inc_indent(indent) + s += indent + "std::vector resultTypes;" + "\n" + + s += indent + type_inference_code + '\n' + + s += indent + "return resultTypes;" + "\n" + indent = dec_indent(indent) + s += indent + "}" + "\n" + + indent = dec_indent(indent) + return s + + def gen_op_def(schema): indent = inc_indent() @@ -648,6 +676,8 @@ def gen_op_def(schema): traits.append("DeclareOpInterfaceMethods") if schema.name in OpsWithPromotableConstOperands.keys(): traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">") + if schema.name in OpsWithResultTypeInference.keys(): + traits.append("OpInterface<\"ResultTypeInferenceOpInterface\">") s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits)) # Generate decl for canonicalizer. @@ -739,10 +769,14 @@ def gen_op_def(schema): s = get_promotable_const_operands_func( s, indent, OpsWithPromotableConstOperands[schema.name]) + if schema.name in OpsWithResultTypeInference: + s = get_type_inference_func( + s, indent, OpsWithResultTypeInference[schema.name]) + s += indent + '}];\n' if ( schema.name in custom_definition_misc) : - s += custom_definition_misc[schema.name] + s += custom_definition_misc[schema.name] + '\n' s += '}\n\n' return s @@ -852,6 +886,7 @@ def build_operator_schemas(): print("Your onnx may be too old." "right version for opertion {} not found".format( schema.name)) + sys.exit() processed_supportmap.append((_support, processed_namemap)) operator_schemas.append((domain, processed_supportmap)) return operator_schemas