fix type inference (#144)
* fix type inference for ConstantOp and MaxPoolSingleOut * modify interface * use OpInterface * change name to OpInterface * Builder dependence * Update CMakeLists.txt * Update CMakeLists.txt Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
6099efd91b
commit
5a542e78a0
|
@ -27,7 +27,7 @@ target_include_directories(OMBuilder
|
|||
# variable definitions when building onnx such as -DONNX_ML=1 -DONNX_NAMESPACE=onnx
|
||||
# will NOT be carried over when compiling FrontendDialectHelper.cpp, etc. so
|
||||
# the compilation will fail.
|
||||
add_dependencies(OMBuilder OMONNXOps)
|
||||
add_dependencies(OMBuilder OMONNXOps OMResultTypeInferenceOpInterface)
|
||||
|
||||
if (INCLUDE_ONNX_ML)
|
||||
add_dependencies(OMBuilder OMMLONNXOps)
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#include <mpark/variant.hpp>
|
||||
namespace bstd = mpark;
|
||||
|
||||
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
||||
|
||||
#include "FrontendDialectTransformer.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
@ -296,6 +298,17 @@ private:
|
|||
|
||||
// TODO: Handle optional inputs.
|
||||
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||
|
||||
// Type inference for results.
|
||||
if (auto opWithTypeInference =
|
||||
mlir::dyn_cast<mlir::ResultTypeInferenceOpInterface>(
|
||||
op.getOperation())) {
|
||||
auto outTypes = opWithTypeInference.resultTypeInference();
|
||||
for (int i = 0; i < node.output().size(); i++) {
|
||||
(*op.getODSResults(i).begin()).setType(outTypes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < node.output().size(); i++) {
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
|
||||
|
|
|
@ -46,6 +46,7 @@ target_link_libraries(onnx-mlir
|
|||
OMShapeInferenceOpInterface
|
||||
OMAttributePromotion
|
||||
OMPromotableConstOperandsOpInterface
|
||||
OMResultTypeInferenceOpInterface
|
||||
OMElideConstants
|
||||
OMElideKrnlGlobalConstants
|
||||
OMKrnlToAffine
|
||||
|
|
|
@ -17,6 +17,7 @@ add_dependencies(OMMLONNXOps OMMLONNXOpsIncGen)
|
|||
# Linking dependencies:
|
||||
add_dependencies(OMMLONNXOps
|
||||
OMPromotableConstOperandsOpInterface
|
||||
OMResultTypeInferenceOpInterface
|
||||
OMShapeInferenceOpInterface)
|
||||
|
||||
add_onnx_mlir_dialect_doc(mlonnx MLONNXOps.td)
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
||||
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td"
|
|||
include "src/Interface/PromotableConstOperandsOpInterface.td"
|
||||
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
|
||||
#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||
#else
|
||||
include "src/Interface/ResultTypeInferenceOpInterface.td"
|
||||
#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||
|
||||
def MLONNX_Dialect : Dialect {
|
||||
let name = "mlonnx";
|
||||
let cppNamespace = "";
|
||||
|
|
|
@ -18,6 +18,7 @@ add_dependencies(OMONNXOps OMONNXOpsIncGen)
|
|||
# Linking dependencies:
|
||||
add_dependencies(OMONNXOps
|
||||
OMPromotableConstOperandsOpInterface
|
||||
OMResultTypeInferenceOpInterface
|
||||
OMShapeInferenceOpInterface)
|
||||
|
||||
add_onnx_mlir_dialect_doc(onnx ONNXOps.td)
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
||||
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -27,6 +27,11 @@ include "src/Interface/ShapeInferenceInterface.td"
|
|||
include "src/Interface/PromotableConstOperandsOpInterface.td"
|
||||
#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
||||
|
||||
#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||
#else
|
||||
include "src/Interface/ResultTypeInferenceOpInterface.td"
|
||||
#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||
|
||||
def ONNX_Dialect : Dialect {
|
||||
let name = "onnx";
|
||||
let cppNamespace = "";
|
||||
|
@ -120,7 +125,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
|||
return 1;
|
||||
}
|
||||
static std::vector<int> getTypeMap() {
|
||||
return {0};
|
||||
return {20};
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -156,7 +161,7 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
|
|||
return 1;
|
||||
}
|
||||
static std::vector<int> getTypeMap() {
|
||||
return {0};
|
||||
return {20};
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
|
|
@ -559,7 +559,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
|
|||
}
|
||||
|
||||
def ONNXConstantOp:ONNX_Op<"Constant",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
||||
let summary = "ONNX Constant operation";
|
||||
let description = [{
|
||||
"A constant tensor. Exactly one of the two attributes, either value or sparse_value,"
|
||||
|
@ -578,6 +578,15 @@ def ONNXConstantOp:ONNX_Op<"Constant",
|
|||
static std::vector<int> getTypeMap() {
|
||||
return {-1};
|
||||
}
|
||||
std::vector<mlir::Type> resultTypeInference() {
|
||||
std::vector<mlir::Type> resultTypes;
|
||||
if (auto attr = valueAttr()) {
|
||||
resultTypes.push_back(attr.getType());
|
||||
} else if (auto attr = sparse_valueAttr()) {
|
||||
resultTypes.push_back(attr.getType());
|
||||
}
|
||||
return resultTypes;
|
||||
}
|
||||
}];
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
|
||||
|
@ -589,7 +598,8 @@ def ONNXConstantOp:ONNX_Op<"Constant",
|
|||
build(builder, state, tensorType, sparse_value, value);
|
||||
}
|
||||
}]>
|
||||
];}
|
||||
];
|
||||
}
|
||||
|
||||
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
||||
[NoSideEffect]> {
|
||||
|
|
|
@ -25,3 +25,17 @@ target_include_directories(OMShapeInferenceOpInterface
|
|||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
add_dependencies(OMShapeInferenceOpInterface ShapeInferenceOpInterfaceIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ResultTypeInferenceOpInterface.td)
|
||||
onnx_mlir_tablegen(ResultTypeInferenceOpInterface.hpp.inc -gen-op-interface-decls)
|
||||
onnx_mlir_tablegen(ResultTypeInferenceOpInterface.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(OMResultTypeInferenceOpInterfaceIncGen)
|
||||
|
||||
add_library(OMResultTypeInferenceOpInterface
|
||||
ResultTypeInferenceOpInterface.hpp
|
||||
ResultTypeInferenceOpInterface.cpp)
|
||||
target_include_directories(OMResultTypeInferenceOpInterface
|
||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
add_dependencies(OMResultTypeInferenceOpInterface
|
||||
OMResultTypeInferenceOpInterfaceIncGen)
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
//===------------ ResultTypeInferenceOpInterface.cpp --------------===//
|
||||
//===--------- Infer Data Type for Results Interface Definition --------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the implementation of the data type inference interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "ResultTypeInferenceOpInterface.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Infer Data Type for Results Interface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/Interface/ResultTypeInferenceOpInterface.cpp.inc"
|
|
@ -0,0 +1,25 @@
|
|||
//===------------ ResultTypeInferenceOpInterface.hpp --------------===//
|
||||
//===------- Infer Data Type for Result of Op Interface Definition -------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the declaration of the data type reference for op
|
||||
// interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Include the auto-generated declarations.
|
||||
#include "src/Interface/ResultTypeInferenceOpInterface.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
|
@ -0,0 +1,34 @@
|
|||
//===------------ ResultTypeInferenceOpInterface.td --------------===//
|
||||
//===--------- Infer Data Type for Results Interface Definition --------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the tablegen of the data type inference interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||
#else
|
||||
#define RESULT_TYPE_INFERENCE_OP_INTERFACE
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def ResultTypeInferenceOpInterface : OpInterface<"ResultTypeInferenceOpInterface"> {
|
||||
let description = [{
|
||||
Interface to access a registered method to infer the data types for
|
||||
the result of an operation
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Infer output data type for this operation class.",
|
||||
"std::vector<mlir::Type>", "resultTypeInference"
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
#endif // RESULT_TYPE_INFERENCE_OP_INTERFACE
|
|
@ -271,6 +271,17 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
|
|||
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
|
||||
"Pad": [("pads", 1), ("constant_value", 2)]}
|
||||
|
||||
# Interface for special handling of type inference
|
||||
# The common code are put into get_type_inference_func
|
||||
OpsWithResultTypeInference = {
|
||||
"Constant":
|
||||
'''if (auto attr = valueAttr()) {
|
||||
resultTypes.push_back(attr.getType());
|
||||
} else if (auto attr = sparse_valueAttr()) {
|
||||
resultTypes.push_back(attr.getType());
|
||||
}'''
|
||||
}
|
||||
|
||||
# Add an Op in this list if the Op needs result type deduction which is required
|
||||
# when writing declarative rewriting rules. Deduced type is always
|
||||
# an UnrankedTensorType whose element type is the same as the first operand's
|
||||
|
@ -634,6 +645,23 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
|
|||
|
||||
return s
|
||||
|
||||
def get_type_inference_func(s, indent, type_inference_code):
|
||||
indent = inc_indent(indent)
|
||||
|
||||
s += indent + "std::vector<mlir::Type> resultTypeInference() {" + "\n"
|
||||
indent = inc_indent(indent)
|
||||
s += indent + "std::vector<mlir::Type> resultTypes;" + "\n"
|
||||
|
||||
s += indent + type_inference_code + '\n'
|
||||
|
||||
s += indent + "return resultTypes;" + "\n"
|
||||
indent = dec_indent(indent)
|
||||
s += indent + "}" + "\n"
|
||||
|
||||
indent = dec_indent(indent)
|
||||
return s
|
||||
|
||||
|
||||
|
||||
def gen_op_def(schema):
|
||||
indent = inc_indent()
|
||||
|
@ -648,6 +676,8 @@ def gen_op_def(schema):
|
|||
traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
|
||||
if schema.name in OpsWithPromotableConstOperands.keys():
|
||||
traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">")
|
||||
if schema.name in OpsWithResultTypeInference.keys():
|
||||
traits.append("OpInterface<\"ResultTypeInferenceOpInterface\">")
|
||||
s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))
|
||||
|
||||
# Generate decl for canonicalizer.
|
||||
|
@ -739,10 +769,14 @@ def gen_op_def(schema):
|
|||
s = get_promotable_const_operands_func(
|
||||
s, indent, OpsWithPromotableConstOperands[schema.name])
|
||||
|
||||
if schema.name in OpsWithResultTypeInference:
|
||||
s = get_type_inference_func(
|
||||
s, indent, OpsWithResultTypeInference[schema.name])
|
||||
|
||||
s += indent + '}];\n'
|
||||
|
||||
if ( schema.name in custom_definition_misc) :
|
||||
s += custom_definition_misc[schema.name]
|
||||
s += custom_definition_misc[schema.name] + '\n'
|
||||
|
||||
s += '}\n\n'
|
||||
return s
|
||||
|
@ -852,6 +886,7 @@ def build_operator_schemas():
|
|||
print("Your onnx may be too old."
|
||||
"right version for opertion {} not found".format(
|
||||
schema.name))
|
||||
sys.exit()
|
||||
processed_supportmap.append((_support, processed_namemap))
|
||||
operator_schemas.append((domain, processed_supportmap))
|
||||
return operator_schemas
|
||||
|
|
Loading…
Reference in New Issue