fix type inference (#144)
* fix type inference for ConstantOp and MaxPoolSingleOut * modify interface * use OpInterface * change name to OpInterface * Builder dependence * Update CMakeLists.txt * Update CMakeLists.txt Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
6099efd91b
commit
5a542e78a0
|
@ -27,7 +27,7 @@ target_include_directories(OMBuilder
|
||||||
# variable definitions when building onnx such as -DONNX_ML=1 -DONNX_NAMESPACE=onnx
|
# variable definitions when building onnx such as -DONNX_ML=1 -DONNX_NAMESPACE=onnx
|
||||||
# will NOT be carried over when compiling FrontendDialectHelper.cpp, etc. so
|
# will NOT be carried over when compiling FrontendDialectHelper.cpp, etc. so
|
||||||
# the compilation will fail.
|
# the compilation will fail.
|
||||||
add_dependencies(OMBuilder OMONNXOps)
|
add_dependencies(OMBuilder OMONNXOps OMResultTypeInferenceOpInterface)
|
||||||
|
|
||||||
if (INCLUDE_ONNX_ML)
|
if (INCLUDE_ONNX_ML)
|
||||||
add_dependencies(OMBuilder OMMLONNXOps)
|
add_dependencies(OMBuilder OMMLONNXOps)
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
#include <mpark/variant.hpp>
|
#include <mpark/variant.hpp>
|
||||||
namespace bstd = mpark;
|
namespace bstd = mpark;
|
||||||
|
|
||||||
|
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
||||||
|
|
||||||
#include "FrontendDialectTransformer.hpp"
|
#include "FrontendDialectTransformer.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
@ -296,6 +298,17 @@ private:
|
||||||
|
|
||||||
// TODO: Handle optional inputs.
|
// TODO: Handle optional inputs.
|
||||||
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||||
|
|
||||||
|
// Type inference for results.
|
||||||
|
if (auto opWithTypeInference =
|
||||||
|
mlir::dyn_cast<mlir::ResultTypeInferenceOpInterface>(
|
||||||
|
op.getOperation())) {
|
||||||
|
auto outTypes = opWithTypeInference.resultTypeInference();
|
||||||
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
|
(*op.getODSResults(i).begin()).setType(outTypes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < node.output().size(); i++) {
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
frontend_symbols_.AddMapping(
|
frontend_symbols_.AddMapping(
|
||||||
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
|
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
|
||||||
|
|
|
@ -46,6 +46,7 @@ target_link_libraries(onnx-mlir
|
||||||
OMShapeInferenceOpInterface
|
OMShapeInferenceOpInterface
|
||||||
OMAttributePromotion
|
OMAttributePromotion
|
||||||
OMPromotableConstOperandsOpInterface
|
OMPromotableConstOperandsOpInterface
|
||||||
|
OMResultTypeInferenceOpInterface
|
||||||
OMElideConstants
|
OMElideConstants
|
||||||
OMElideKrnlGlobalConstants
|
OMElideKrnlGlobalConstants
|
||||||
OMKrnlToAffine
|
OMKrnlToAffine
|
||||||
|
|
|
@ -17,6 +17,7 @@ add_dependencies(OMMLONNXOps OMMLONNXOpsIncGen)
|
||||||
# Linking dependencies:
|
# Linking dependencies:
|
||||||
add_dependencies(OMMLONNXOps
|
add_dependencies(OMMLONNXOps
|
||||||
OMPromotableConstOperandsOpInterface
|
OMPromotableConstOperandsOpInterface
|
||||||
|
OMResultTypeInferenceOpInterface
|
||||||
OMShapeInferenceOpInterface)
|
OMShapeInferenceOpInterface)
|
||||||
|
|
||||||
add_onnx_mlir_dialect_doc(mlonnx MLONNXOps.td)
|
add_onnx_mlir_dialect_doc(mlonnx MLONNXOps.td)
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
||||||
|
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
||||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
|
@ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td"
|
||||||
include "src/Interface/PromotableConstOperandsOpInterface.td"
|
include "src/Interface/PromotableConstOperandsOpInterface.td"
|
||||||
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
|
||||||
|
#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||||
|
#else
|
||||||
|
include "src/Interface/ResultTypeInferenceOpInterface.td"
|
||||||
|
#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||||
|
|
||||||
def MLONNX_Dialect : Dialect {
|
def MLONNX_Dialect : Dialect {
|
||||||
let name = "mlonnx";
|
let name = "mlonnx";
|
||||||
let cppNamespace = "";
|
let cppNamespace = "";
|
||||||
|
|
|
@ -18,6 +18,7 @@ add_dependencies(OMONNXOps OMONNXOpsIncGen)
|
||||||
# Linking dependencies:
|
# Linking dependencies:
|
||||||
add_dependencies(OMONNXOps
|
add_dependencies(OMONNXOps
|
||||||
OMPromotableConstOperandsOpInterface
|
OMPromotableConstOperandsOpInterface
|
||||||
|
OMResultTypeInferenceOpInterface
|
||||||
OMShapeInferenceOpInterface)
|
OMShapeInferenceOpInterface)
|
||||||
|
|
||||||
add_onnx_mlir_dialect_doc(onnx ONNXOps.td)
|
add_onnx_mlir_dialect_doc(onnx ONNXOps.td)
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
||||||
|
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
||||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
|
@ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td"
|
||||||
include "src/Interface/PromotableConstOperandsOpInterface.td"
|
include "src/Interface/PromotableConstOperandsOpInterface.td"
|
||||||
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||||
|
|
||||||
|
#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||||
|
#else
|
||||||
|
include "src/Interface/ResultTypeInferenceOpInterface.td"
|
||||||
|
#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||||
|
|
||||||
def ONNX_Dialect : Dialect {
|
def ONNX_Dialect : Dialect {
|
||||||
let name = "onnx";
|
let name = "onnx";
|
||||||
let cppNamespace = "";
|
let cppNamespace = "";
|
||||||
|
@ -120,7 +125,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
static std::vector<int> getTypeMap() {
|
static std::vector<int> getTypeMap() {
|
||||||
return {0};
|
return {20};
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
@ -156,7 +161,7 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
static std::vector<int> getTypeMap() {
|
static std::vector<int> getTypeMap() {
|
||||||
return {0};
|
return {20};
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
|
@ -559,7 +559,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXConstantOp:ONNX_Op<"Constant",
|
def ONNXConstantOp:ONNX_Op<"Constant",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
||||||
let summary = "ONNX Constant operation";
|
let summary = "ONNX Constant operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"A constant tensor. Exactly one of the two attributes, either value or sparse_value,"
|
"A constant tensor. Exactly one of the two attributes, either value or sparse_value,"
|
||||||
|
@ -578,6 +578,15 @@ def ONNXConstantOp:ONNX_Op<"Constant",
|
||||||
static std::vector<int> getTypeMap() {
|
static std::vector<int> getTypeMap() {
|
||||||
return {-1};
|
return {-1};
|
||||||
}
|
}
|
||||||
|
std::vector<mlir::Type> resultTypeInference() {
|
||||||
|
std::vector<mlir::Type> resultTypes;
|
||||||
|
if (auto attr = valueAttr()) {
|
||||||
|
resultTypes.push_back(attr.getType());
|
||||||
|
} else if (auto attr = sparse_valueAttr()) {
|
||||||
|
resultTypes.push_back(attr.getType());
|
||||||
|
}
|
||||||
|
return resultTypes;
|
||||||
|
}
|
||||||
}];
|
}];
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
|
||||||
|
@ -589,7 +598,8 @@ def ONNXConstantOp:ONNX_Op<"Constant",
|
||||||
build(builder, state, tensorType, sparse_value, value);
|
build(builder, state, tensorType, sparse_value, value);
|
||||||
}
|
}
|
||||||
}]>
|
}]>
|
||||||
];}
|
];
|
||||||
|
}
|
||||||
|
|
||||||
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
|
|
|
@ -25,3 +25,17 @@ target_include_directories(OMShapeInferenceOpInterface
|
||||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNX_MLIR_SRC_ROOT})
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen)
|
add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS ResultTypeInferenceOpInterface.td)
|
||||||
|
onnx_mlir_tablegen(ResultTypeInferenceOpInterface.hpp.inc -gen-op-interface-decls)
|
||||||
|
onnx_mlir_tablegen(ResultTypeInferenceOpInterface.cpp.inc -gen-op-interface-defs)
|
||||||
|
add_public_tablegen_target(OMResultTypeInferenceOpInterfaceIncGen)
|
||||||
|
|
||||||
|
add_library(OMResultTypeInferenceOpInterface
|
||||||
|
ResultTypeInferenceOpInterface.hpp
|
||||||
|
ResultTypeInferenceOpInterface.cpp)
|
||||||
|
target_include_directories(OMResultTypeInferenceOpInterface
|
||||||
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
|
add_dependencies(OMResultTypeInferenceOpInterface
|
||||||
|
OMResultTypeInferenceOpInterfaceIncGen)
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
//===------------ ResultTypeInferenceOpInterface.cpp --------------===//
|
||||||
|
//===--------- Infer Data Type for Results Interface Definition --------===//
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains the implementation of the data type inference interface.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "ResultTypeInferenceOpInterface.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Infer Data Type for Results Interface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "src/Interface/ResultTypeInferenceOpInterface.cpp.inc"
|
|
@ -0,0 +1,25 @@
|
||||||
|
//===------------ ResultTypeInferenceOpInterface.hpp --------------===//
|
||||||
|
//===------- Infer Data Type for Result of Op Interface Definition -------===//
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains the declaration of the data type reference for op
|
||||||
|
// interface.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
/// Include the auto-generated declarations.
|
||||||
|
#include "src/Interface/ResultTypeInferenceOpInterface.hpp.inc"
|
||||||
|
|
||||||
|
} // end namespace mlir
|
|
@ -0,0 +1,34 @@
|
||||||
|
//===------------ ResultTypeInferenceOpInterface.td --------------===//
|
||||||
|
//===--------- Infer Data Type for Results Interface Definition --------===//
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains the tablegen of the data type inference interface.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||||
|
#else
|
||||||
|
#define RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||||
|
|
||||||
|
#ifdef OP_BASE
|
||||||
|
#else
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def ResultTypeInferenceOpInterface : OpInterface<"ResultTypeInferenceOpInterface"> {
|
||||||
|
let description = [{
|
||||||
|
Interface to access a registered method to infer the data types for
|
||||||
|
the result of an operation
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<"Infer output data type for this operation class.",
|
||||||
|
"std::vector<mlir::Type>", "resultTypeInference"
|
||||||
|
>
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE
|
|
@ -271,6 +271,17 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
|
||||||
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
|
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
|
||||||
"Pad": [("pads", 1), ("constant_value", 2)]}
|
"Pad": [("pads", 1), ("constant_value", 2)]}
|
||||||
|
|
||||||
|
# Interface for special handling of type inference
|
||||||
|
# The common code are put into get_type_inference_func
|
||||||
|
OpsWithResultTypeInference = {
|
||||||
|
"Constant":
|
||||||
|
'''if (auto attr = valueAttr()) {
|
||||||
|
resultTypes.push_back(attr.getType());
|
||||||
|
} else if (auto attr = sparse_valueAttr()) {
|
||||||
|
resultTypes.push_back(attr.getType());
|
||||||
|
}'''
|
||||||
|
}
|
||||||
|
|
||||||
# Add an Op in this list if the Op needs result type deduction which is required
|
# Add an Op in this list if the Op needs result type deduction which is required
|
||||||
# when writing declarative rewriting rules. Deduced type is always
|
# when writing declarative rewriting rules. Deduced type is always
|
||||||
# an UnrankedTensorType whose element type is the same as the first operand's
|
# an UnrankedTensorType whose element type is the same as the first operand's
|
||||||
|
@ -634,6 +645,23 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
def get_type_inference_func(s, indent, type_inference_code):
|
||||||
|
indent = inc_indent(indent)
|
||||||
|
|
||||||
|
s += indent + "std::vector<mlir::Type> resultTypeInference() {" + "\n"
|
||||||
|
indent = inc_indent(indent)
|
||||||
|
s += indent + "std::vector<mlir::Type> resultTypes;" + "\n"
|
||||||
|
|
||||||
|
s += indent + type_inference_code + '\n'
|
||||||
|
|
||||||
|
s += indent + "return resultTypes;" + "\n"
|
||||||
|
indent = dec_indent(indent)
|
||||||
|
s += indent + "}" + "\n"
|
||||||
|
|
||||||
|
indent = dec_indent(indent)
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def gen_op_def(schema):
|
def gen_op_def(schema):
|
||||||
indent = inc_indent()
|
indent = inc_indent()
|
||||||
|
@ -648,6 +676,8 @@ def gen_op_def(schema):
|
||||||
traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
|
traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
|
||||||
if schema.name in OpsWithPromotableConstOperands.keys():
|
if schema.name in OpsWithPromotableConstOperands.keys():
|
||||||
traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">")
|
traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">")
|
||||||
|
if schema.name in OpsWithResultTypeInference.keys():
|
||||||
|
traits.append("OpInterface<\"ResultTypeInferenceOpInterface\">")
|
||||||
s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))
|
s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))
|
||||||
|
|
||||||
# Generate decl for canonicalizer.
|
# Generate decl for canonicalizer.
|
||||||
|
@ -739,10 +769,14 @@ def gen_op_def(schema):
|
||||||
s = get_promotable_const_operands_func(
|
s = get_promotable_const_operands_func(
|
||||||
s, indent, OpsWithPromotableConstOperands[schema.name])
|
s, indent, OpsWithPromotableConstOperands[schema.name])
|
||||||
|
|
||||||
|
if schema.name in OpsWithResultTypeInference:
|
||||||
|
s = get_type_inference_func(
|
||||||
|
s, indent, OpsWithResultTypeInference[schema.name])
|
||||||
|
|
||||||
s += indent + '}];\n'
|
s += indent + '}];\n'
|
||||||
|
|
||||||
if ( schema.name in custom_definition_misc) :
|
if ( schema.name in custom_definition_misc) :
|
||||||
s += custom_definition_misc[schema.name]
|
s += custom_definition_misc[schema.name] + '\n'
|
||||||
|
|
||||||
s += '}\n\n'
|
s += '}\n\n'
|
||||||
return s
|
return s
|
||||||
|
@ -852,6 +886,7 @@ def build_operator_schemas():
|
||||||
print("Your onnx may be too old."
|
print("Your onnx may be too old."
|
||||||
"right version for opertion {} not found".format(
|
"right version for opertion {} not found".format(
|
||||||
schema.name))
|
schema.name))
|
||||||
|
sys.exit()
|
||||||
processed_supportmap.append((_support, processed_namemap))
|
processed_supportmap.append((_support, processed_namemap))
|
||||||
operator_schemas.append((domain, processed_supportmap))
|
operator_schemas.append((domain, processed_supportmap))
|
||||||
return operator_schemas
|
return operator_schemas
|
||||||
|
|
Loading…
Reference in New Issue