fix type inference (#144)

* fix type inference for ConstantOp and MaxPoolSingleOut

* modify interface

* use OpInterface

* change name to OpInterface

* Builder dependence

* Update CMakeLists.txt

* Update CMakeLists.txt

Co-authored-by: Tian Jin <tjingrant@gmail.com>
commit 5a542e78a0 (parent 6099efd91b)
Author: chentong319, 2020-05-25 21:54:19 -04:00 (committed by GitHub)
15 changed files with 174 additions and 8 deletions


@@ -27,7 +27,7 @@ target_include_directories(OMBuilder
 # variable definitions when building onnx such as -DONNX_ML=1 -DONNX_NAMESPACE=onnx
 # will NOT be carried over when compiling FrontendDialectHelper.cpp, etc. so
 # the compilation will fail.
-add_dependencies(OMBuilder OMONNXOps)
+add_dependencies(OMBuilder OMONNXOps OMResultTypeInferenceOpInterface)
 if (INCLUDE_ONNX_ML)
   add_dependencies(OMBuilder OMMLONNXOps)


@@ -20,6 +20,8 @@
 #include <mpark/variant.hpp>
 namespace bstd = mpark;
+#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
 #include "FrontendDialectTransformer.hpp"
 namespace onnx_mlir {
@@ -296,6 +298,17 @@ private:
     // TODO: Handle optional inputs.
     auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
+    // Type inference for results.
+    if (auto opWithTypeInference =
+            mlir::dyn_cast<mlir::ResultTypeInferenceOpInterface>(
+                op.getOperation())) {
+      auto outTypes = opWithTypeInference.resultTypeInference();
+      for (int i = 0; i < node.output().size(); i++) {
+        (*op.getODSResults(i).begin()).setType(outTypes[i]);
+      }
+    }
     for (int i = 0; i < node.output().size(); i++) {
       frontend_symbols_.AddMapping(
           legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
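
For readers less familiar with MLIR op interfaces, the pattern applied in the hunk above can be written as a standalone sketch. The helper name overrideResultTypes below is illustrative only and not part of this commit; it assumes the interface header introduced by this commit is available.

    // Sketch (not part of this commit): refine the result types of an already
    // created operation via ResultTypeInferenceOpInterface, mirroring the
    // importer change above.
    #include <vector>

    #include "mlir/IR/Operation.h"
    #include "src/Interface/ResultTypeInferenceOpInterface.hpp"

    static void overrideResultTypes(mlir::Operation *op) {
      // Only ops implementing the interface (e.g. ONNXConstantOp after this
      // commit) are refined; all other ops are left untouched.
      if (auto opWithTypeInference =
              mlir::dyn_cast<mlir::ResultTypeInferenceOpInterface>(op)) {
        std::vector<mlir::Type> outTypes =
            opWithTypeInference.resultTypeInference();
        // The interface is expected to return one type per op result.
        for (unsigned i = 0; i < op->getNumResults() && i < outTypes.size(); ++i)
          op->getResult(i).setType(outTypes[i]);
      }
    }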


@@ -46,6 +46,7 @@ target_link_libraries(onnx-mlir
         OMShapeInferenceOpInterface
         OMAttributePromotion
         OMPromotableConstOperandsOpInterface
+        OMResultTypeInferenceOpInterface
         OMElideConstants
         OMElideKrnlGlobalConstants
         OMKrnlToAffine


@@ -17,6 +17,7 @@ add_dependencies(OMMLONNXOps OMMLONNXOpsIncGen)
 # Linking dependencies:
 add_dependencies(OMMLONNXOps
         OMPromotableConstOperandsOpInterface
+        OMResultTypeInferenceOpInterface
         OMShapeInferenceOpInterface)
 add_onnx_mlir_dialect_doc(mlonnx MLONNXOps.td)


@@ -20,6 +20,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
+#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
 #include "src/Interface/ShapeInferenceInterface.hpp"
 namespace mlir {


@@ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td"
 include "src/Interface/PromotableConstOperandsOpInterface.td"
 #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
+#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
+#else
+include "src/Interface/ResultTypeInferenceOpInterface.td"
+#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE
 def MLONNX_Dialect : Dialect {
   let name = "mlonnx";
   let cppNamespace = "";


@@ -18,6 +18,7 @@ add_dependencies(OMONNXOps OMONNXOpsIncGen)
 # Linking dependencies:
 add_dependencies(OMONNXOps
         OMPromotableConstOperandsOpInterface
+        OMResultTypeInferenceOpInterface
         OMShapeInferenceOpInterface)
 add_onnx_mlir_dialect_doc(onnx ONNXOps.td)


@@ -20,6 +20,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
+#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
 #include "src/Interface/ShapeInferenceInterface.hpp"
 namespace mlir {


@@ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td"
 include "src/Interface/PromotableConstOperandsOpInterface.td"
 #endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
+#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
+#else
+include "src/Interface/ResultTypeInferenceOpInterface.td"
+#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE
 def ONNX_Dialect : Dialect {
   let name = "onnx";
   let cppNamespace = "";
@@ -120,7 +125,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
       return 1;
     }
     static std::vector<int> getTypeMap() {
-      return {0};
+      return {20};
     }
   }];
 }
@@ -156,7 +161,7 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
       return 1;
     }
     static std::vector<int> getTypeMap() {
-      return {0};
+      return {20};
     }
   }];
 }


@@ -559,7 +559,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
 }
 def ONNXConstantOp:ONNX_Op<"Constant",
-  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
   let summary = "ONNX Constant operation";
   let description = [{
   "A constant tensor. Exactly one of the two attributes, either value or sparse_value,"
@@ -578,6 +578,15 @@ def ONNXConstantOp:ONNX_Op<"Constant",
     static std::vector<int> getTypeMap() {
       return {-1};
     }
+    std::vector<mlir::Type> resultTypeInference() {
+      std::vector<mlir::Type> resultTypes;
+      if (auto attr = valueAttr()) {
+        resultTypes.push_back(attr.getType());
+      } else if (auto attr = sparse_valueAttr()) {
+        resultTypes.push_back(attr.getType());
+      }
+      return resultTypes;
+    }
   }];
   let builders = [
   OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
@@ -589,7 +598,8 @@ def ONNXConstantOp:ONNX_Op<"Constant",
     build(builder, state, tensorType, sparse_value, value);
   }
   }]>
-];}
+];
+}
 def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
   [NoSideEffect]> {


@@ -25,3 +25,17 @@ target_include_directories(OMShapeInferenceOpInterface
   PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
   ${ONNX_MLIR_SRC_ROOT})
 add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen)
+set(LLVM_TARGET_DEFINITIONS ResultTypeInferenceOpInterface.td)
+onnx_mlir_tablegen(ResultTypeInferenceOpInterface.hpp.inc -gen-op-interface-decls)
+onnx_mlir_tablegen(ResultTypeInferenceOpInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(OMResultTypeInferenceOpInterfaceIncGen)
+add_library(OMResultTypeInferenceOpInterface
+  ResultTypeInferenceOpInterface.hpp
+  ResultTypeInferenceOpInterface.cpp)
+target_include_directories(OMResultTypeInferenceOpInterface
+  PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
+  ${ONNX_MLIR_SRC_ROOT})
+add_dependencies(OMResultTypeInferenceOpInterface
+  OMResultTypeInferenceOpInterfaceIncGen)


@@ -0,0 +1,20 @@
//===------------ ResultTypeInferenceOpInterface.cpp --------------===//
//===--------- Infer Data Type for Results Interface Definition --------===//
//
// Copyright 2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains the implementation of the data type inference interface.
//
//===----------------------------------------------------------------------===//
#include "ResultTypeInferenceOpInterface.hpp"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Infer Data Type for Results Interface
//===----------------------------------------------------------------------===//
#include "src/Interface/ResultTypeInferenceOpInterface.cpp.inc"


@@ -0,0 +1,25 @@
//===------------ ResultTypeInferenceOpInterface.hpp --------------===//
//===------- Infer Data Type for Result of Op Interface Definition -------===//
//
// Copyright 2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains the declaration of the data type reference for op
// interface.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <map>
#include <string>
#include "mlir/IR/OpDefinition.h"
namespace mlir {
/// Include the auto-generated declarations.
#include "src/Interface/ResultTypeInferenceOpInterface.hpp.inc"
} // end namespace mlir


@@ -0,0 +1,34 @@
//===------------ ResultTypeInferenceOpInterface.td --------------===//
//===--------- Infer Data Type for Results Interface Definition --------===//
//
// Copyright 2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains the tablegen of the data type inference interface.
//
//===----------------------------------------------------------------------===//
#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
#else
#define RESULT_TYPE_INFERENCE_OP_INTERFACE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def ResultTypeInferenceOpInterface : OpInterface<"ResultTypeInferenceOpInterface"> {
let description = [{
Interface to access a registered method to infer the data types for
the result of an operation
}];
let methods = [
InterfaceMethod<"Infer output data type for this operation class.",
"std::vector<mlir::Type>", "resultTypeInference"
>
];
}
#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE


@@ -271,6 +271,17 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
 OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
                                   "Pad": [("pads", 1), ("constant_value", 2)]}
+# Interface for special handling of type inference
+# The common code are put into get_type_inference_func
+OpsWithResultTypeInference = {
+  "Constant":
+    '''if (auto attr = valueAttr()) {
+        resultTypes.push_back(attr.getType());
+      } else if (auto attr = sparse_valueAttr()) {
+        resultTypes.push_back(attr.getType());
+      }'''
+}
 # Add an Op in this list if the Op needs result type deduction which is required
 # when writing declarative rewriting rules. Deduced type is always
 # an UnrankedTensorType whose element type is the same as the first operand's
@@ -634,6 +645,23 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
     return s
+def get_type_inference_func(s, indent, type_inference_code):
+    indent = inc_indent(indent)
+    s += indent + "std::vector<mlir::Type> resultTypeInference() {" + "\n"
+    indent = inc_indent(indent)
+    s += indent + "std::vector<mlir::Type> resultTypes;" + "\n"
+    s += indent + type_inference_code + '\n'
+    s += indent + "return resultTypes;" + "\n"
+    indent = dec_indent(indent)
+    s += indent + "}" + "\n"
+    indent = dec_indent(indent)
+    return s
 def gen_op_def(schema):
     indent = inc_indent()
@@ -648,6 +676,8 @@ def gen_op_def(schema):
         traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
     if schema.name in OpsWithPromotableConstOperands.keys():
         traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">")
+    if schema.name in OpsWithResultTypeInference.keys():
+        traits.append("OpInterface<\"ResultTypeInferenceOpInterface\">")
     s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))
     # Generate decl for canonicalizer.
@@ -739,10 +769,14 @@ def gen_op_def(schema):
         s = get_promotable_const_operands_func(
             s, indent, OpsWithPromotableConstOperands[schema.name])
+    if schema.name in OpsWithResultTypeInference:
+        s = get_type_inference_func(
+            s, indent, OpsWithResultTypeInference[schema.name])
     s += indent + '}];\n'
     if ( schema.name in custom_definition_misc) :
-        s += custom_definition_misc[schema.name]
+        s += custom_definition_misc[schema.name] + '\n'
     s += '}\n\n'
     return s
@@ -852,6 +886,7 @@ def build_operator_schemas():
                 print("Your onnx may be too old."
                       "right version for opertion {} not found".format(
                           schema.name))
+                sys.exit()
         processed_supportmap.append((_support, processed_namemap))
     operator_schemas.append((domain, processed_supportmap))
     return operator_schemas
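
For reference, combining get_type_inference_func with the "Constant" entry in OpsWithResultTypeInference reproduces the member shown earlier in the ONNXConstantOp hunk; the generator emits (modulo indentation) the following C++ into the op's extraClassDeclaration:

    // C++ member emitted into ONNXConstantOp's TableGen extraClassDeclaration.
    // The result type is taken from whichever constant attribute is present.
    std::vector<mlir::Type> resultTypeInference() {
      std::vector<mlir::Type> resultTypes;
      if (auto attr = valueAttr()) {
        resultTypes.push_back(attr.getType());
      } else if (auto attr = sparse_valueAttr()) {
        resultTypes.push_back(attr.getType());
      }
      return resultTypes;
    }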