Add result type inference to op definition (#87)

* Add result type inference to op definition

* Edit MLIR tests

* Fix result type for Mul

* Format comments

* Return UnrankedTensorType as result type

* Just for testing -split-input-file

* Undo: Just for testing -split-input-file

* Extract a function, get_operand_ins, that gets operand types; rewrite gen_attr_ins function

* Generate custom builders

* Call existing build methods

* Add comments

* Minor changes

* Generate build methods with attributes

* Add support of variadic type

* Do not generate custom build methods for ops having only attributes

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-02-21 23:28:24 +09:00 committed by GitHub
parent aea6479ad3
commit 479dd5e35a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 262 additions and 292 deletions

View File

@ -54,6 +54,15 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
'ReduceLogSumExp', 'ReduceSumSquare'] 'ReduceLogSumExp', 'ReduceSumSquare']
#add an Op in this list if the Op needs result type deduction which is required
#when writing declarative rewriting rules. Deduced type is always
#an UnrankedTensorType whose element type is the same as the first operand's
#element type.
#currenlty, there are only two build methods generated:
# - one with operands and attributes having a separate parameter, and
# - one with operands and attributes having aggregated parameters.
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare']
manual_code_in_op_def = dict([ manual_code_in_op_def = dict([
('DummyExample', ' let extraClassDeclaration = [{ \n'+ ('DummyExample', ' let extraClassDeclaration = [{ \n'+
' static StringRef getPermAttrName() { return "perm"; }\n'+ ' static StringRef getPermAttrName() { return "perm"; }\n'+
@ -345,38 +354,23 @@ def gen_schema(schema) :
#input #input
s+= '\n'+line_indent+'let arguments = (ins ' s+= '\n'+line_indent+'let arguments = (ins '
isfirst = True isfirst = True
if schema.inputs: # add operands
isfirst = False operand_ins = get_operand_ins(schema)
for input in schema.inputs: for operand_type, operand_name in operand_ins:
if input != schema.inputs[0] : if not isfirst:
s+= ',\n ' s+= ',\n '
etypes=collect_types(schema, input) else:
isfirst = False
s+=operand_type+':$'+operand_name
if OpSchema.FormalParameterOption.Optional == input.option: # add attributes
#TODO: handle optional attr_ins = get_attr_ins(schema)
print("warning: optional input for"+schema.name+' '+input.name) for attr_type, attr_name in attr_ins:
elif OpSchema.FormalParameterOption.Variadic == input.option: if not isfirst:
if input.isHomogeneous: s += ',\n '
s+= 'Variadic<'
else : else :
#TODO handle (variadic, heterogeneous)" isfirst = False
print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name) s += attr_type+':$'+attr_name
if etypes == '':
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
else:
s+= 'TensorOf<['+etypes+']>'
if OpSchema.FormalParameterOption.Optional == input.option:
#TODO: handle optional
t=''
elif OpSchema.FormalParameterOption.Variadic == input.option:
if input.isHomogeneous:
s+= '>'
else:
#TODO handle (variadic, heterogeneous)"
t=''
s+=':$'+input.name
s += gen_attr_ins(schema, isfirst)
s+= ');' s+= ');'
#output #output
@ -395,6 +389,71 @@ def gen_schema(schema) :
s+= ');\n' s+= ');\n'
#s+= 'let hasCanonicalizer = 1;' #s+= 'let hasCanonicalizer = 1;'
#TODO: any better way to do this.
def get_attr_type_for_builder(attr_type) :
if 'I64Attr' in attr_type :
mytype = 'IntegerAttr'
elif 'F32Attr' in attr_type :
mytype = 'FloatAttr'
elif 'I64ArrayAttr' in attr_type or 'F32ArrayAttr' in attr_type:
mytype = 'ArrayAttr'
elif 'StrAttr' in attr_type :
mytype = 'StringAttr'
elif 'strings' in attr_type :
mytype = 'ArrayAttr'
else :
mytype ='Attribute'
return mytype
def get_op_type_for_builder(op_type):
if op_type.startswith('Variadic'):
mytype = 'ValueRange'
else:
mytype = 'Value'
return mytype
# add custom builders
# use element type of the first operand to construct an UnrankedTensorType for the output.
if schema.name in custom_builder_ops_list:
if len(operand_ins) == 0:
print("warning: not generate custom build methods for " + schema.name + " since it does not have operands.")
else:
if get_op_type_for_builder(operand_ins[0][0]) == 'ValueRange':
first_operand = operand_ins[0][1]+'[0]'
else:
first_operand = operand_ins[0][1]
s += line_indent+'let builders = [\n'
# custom builders with operands and attributes having a seperate parameter.
# E.g. OpBuilder<"Builder *builder, OperationState &state, Value X, Value, Y, Attribute A", [{}]>
s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state'
for arg_type, arg_name in operand_ins:
s += ', '+get_op_type_for_builder(arg_type)+' '+arg_name
for attr_type, attr_name in attr_ins:
s += ', '+get_attr_type_for_builder(attr_type)+' '+attr_name
s += '", [{\n'
s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast<TensorType>().getElementType();\n'
s += line_indent*3+'build(builder, state, UnrankedTensorType::get(elementType)'
for _, arg_name in operand_ins:
s += ', '+arg_name
for _, attr_name in attr_ins:
s += ', '+attr_name
s += ');\n'
s += line_indent*2+'}]>,\n'
# custom builders with all operands and attributes having aggregate parameters.
# E.g. OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{}]>'
s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast<TensorType>().getElementType();\n'
s += line_indent*3+'std::vector<mlir::Type> outputTypes;\n'
s += line_indent*3+'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n'
s += line_indent*3+'build(builder, state, outputTypes, operands, attributes);\n'
s += line_indent*2+'}]>'
s += '\n'+line_indent+'];\n'
#add special code #add special code
if schema.name in manual_code_in_op_def : if schema.name in manual_code_in_op_def :
s += manual_code_in_op_def[schema.name] s += manual_code_in_op_def[schema.name]
@ -447,7 +506,41 @@ def gen_code(schema,fefile) :
else: else:
fefile.write(', '+variadicIn+', '+variadicOut+');\n') fefile.write(', '+variadicIn+', '+variadicOut+');\n')
def gen_attr_ins(schema, isfirst) : def get_operand_ins(schema):
operand_type_and_name_list = [] # [(optype, opname)]
if schema.inputs:
for input in schema.inputs:
optype = ""
etypes=collect_types(schema, input)
if OpSchema.FormalParameterOption.Optional == input.option:
#TODO : handle optional
print("warning: optional input for"+schema.name+' '+input.name)
elif OpSchema.FormalParameterOption.Variadic == input.option:
if input.isHomogeneous:
optype += 'Variadic<'
else:
#TODO handle(variadic, heterogeneous) "
print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name)
if etypes == '':
optype += 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
else:
optype += 'TensorOf<['+etypes+']>'
if OpSchema.FormalParameterOption.Optional == input.option:
#TODO : handle optional
t=''
elif OpSchema.FormalParameterOption.Variadic == input.option:
if input.isHomogeneous:
optype += '>'
else:
#TODO handle(variadic, heterogeneous) "
t=''
operand_type_and_name_list.append((optype, input.name))
return operand_type_and_name_list
def get_attr_ins(schema) :
def get_attr_type_basic(attr_type) : def get_attr_type_basic(attr_type) :
if attr_type == 'int' : if attr_type == 'int' :
@ -479,24 +572,22 @@ def gen_attr_ins(schema, isfirst) :
mytype += ', "'+attr_default+'">' mytype += ', "'+attr_default+'">'
return mytype return mytype
attr_type_and_name_list = [] # :: [(attrtype, attrname)]
attr_line = '' attr_line = ''
if schema.attributes: if schema.attributes:
for _, attr in sorted(schema.attributes.items()): for _, attr in sorted(schema.attributes.items()):
#attr_line = line_indent+line_indent+line_indent+line_indent #attr_line = line_indent+line_indent+line_indent+line_indent
if not isfirst: found = False
attr_line += ',\n ' attr_type = ""
else :
isfirst = False
if schema.name+' '+attr.name in special_attr_defaults: if schema.name+' '+attr.name in special_attr_defaults:
(attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name] (attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name]
attr_line += get_attr_type_with_default(attr_type_str, attr_default_str) attr_type = get_attr_type_with_default(attr_type_str, attr_default_str)
attr_line += ':$'+attr.name found = True
elif attr.required: elif attr.required:
s = Text(attr.type) s = Text(attr.type)
attr_type_str = s[s.rfind('.') + 1:].lower() attr_type_str = s[s.rfind('.') + 1:].lower()
attr_line += get_attr_type_basic(attr_type_str) attr_type = get_attr_type_basic(attr_type_str)
attr_line += ':$'+attr.name found = True
# option holds either required or default value # option holds either required or default value
elif attr.default_value.name: elif attr.default_value.name:
@ -527,14 +618,15 @@ def gen_attr_ins(schema, isfirst) :
else: else:
default_value = format_value(default_value) default_value = format_value(default_value)
attr_option_str = default_value attr_option_str = default_value
attr_line += get_attr_type_with_default(attr_type_str, attr_option_str) attr_type = get_attr_type_with_default(attr_type_str, attr_option_str)
attr_line += ':$'+attr.name found = True
else: else:
s = Text(attr.type) s = Text(attr.type)
attr_type_str = s[s.rfind('.') + 1:].lower() attr_type_str = s[s.rfind('.') + 1:].lower()
attr_line += get_attr_type_optional(attr_type_str) attr_type = get_attr_type_optional(attr_type_str)
attr_line += ':$'+attr.name if found:
return attr_line attr_type_and_name_list.append((attr_type, attr.name))
return attr_type_and_name_list
def main(args): # type: (Type[Args]) -> None def main(args): # type: (Type[Args]) -> None
with io.open(args.changelog, 'w', newline='') as fout: with io.open(args.changelog, 'w', newline='') as fout:

View File

@ -36,6 +36,11 @@ onnf_tablegen(onnx_combine.inc -gen-rewriters)
add_public_tablegen_target(gen_onnx_combine) add_public_tablegen_target(gen_onnx_combine)
add_dependencies(compiler gen_onnx_combine) add_dependencies(compiler gen_onnx_combine)
set(LLVM_TARGET_DEFINITIONS pass/onnx_rewrite.td)
onnf_tablegen(onnx_rewrite.inc -gen-rewriters)
add_public_tablegen_target(gen_onnx_rewrite)
add_dependencies(compiler gen_onnx_rewrite)
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")

View File

@ -14,6 +14,18 @@ def ONNXAbsOp:ONNX_Op<"Abs",
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let builders = [
OpBuilder<"Builder *builder, OperationState &state, Value X", [{
auto elementType = X.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), X);
}]>,
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
} }
def ONNXAcosOp:ONNX_Op<"Acos", def ONNXAcosOp:ONNX_Op<"Acos",
@ -649,6 +661,18 @@ def ONNXExpOp:ONNX_Op<"Exp",
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
let builders = [
OpBuilder<"Builder *builder, OperationState &state, Value input", [{
auto elementType = input.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), input);
}]>,
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
} }
def ONNXExpandOp:ONNX_Op<"Expand", def ONNXExpandOp:ONNX_Op<"Expand",
@ -1763,6 +1787,18 @@ def ONNXMulOp:ONNX_Op<"Mul",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C);
let builders = [
OpBuilder<"Builder *builder, OperationState &state, Value A, Value B", [{
auto elementType = A.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), A, B);
}]>,
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
} }
def ONNXMultinomialOp:ONNX_Op<"Multinomial", def ONNXMultinomialOp:ONNX_Op<"Multinomial",
@ -2431,6 +2467,18 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
OptionalAttr<I64ArrayAttr>:$axes, OptionalAttr<I64ArrayAttr>:$axes,
DefaultValuedAttr<I64Attr, "1">:$keepdims); DefaultValuedAttr<I64Attr, "1">:$keepdims);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced);
let builders = [
OpBuilder<"Builder *builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{
auto elementType = data.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), data, axes, keepdims);
}]>,
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
} }
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
@ -2449,6 +2497,18 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
OptionalAttr<I64ArrayAttr>:$axes, OptionalAttr<I64ArrayAttr>:$axes,
DefaultValuedAttr<I64Attr, "1">:$keepdims); DefaultValuedAttr<I64Attr, "1">:$keepdims);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced);
let builders = [
OpBuilder<"Builder *builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{
auto elementType = data.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), data, axes, keepdims);
}]>,
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
} }
def ONNXReluOp:ONNX_Op<"Relu", def ONNXReluOp:ONNX_Op<"Relu",

View File

@ -17,252 +17,8 @@
using namespace mlir; using namespace mlir;
namespace { namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
// There are two ways to write rewrite rules: #include "src/onnx_rewrite.inc"
// - Declarative manner: specify rewrite rules in a TableGen record, and
// - Manual Manner: subclass the mlir::RewritePattern.
//
// We prefer to use the former way as much as possible. However, there is a
// limitation about operation definition specification (ODS) in TableGen that
// requires us to write custom builders, that is
// "all ODS-generated `build()` methods require specifying the result type(s),
// unless the op has known traits like `SameOperandsAndResultType` that we can
// use to auto-generate a `build()` method with result type deduction".
//
// More information about the limitation can be found here:
// https://github.com/llvm/llvm-project/blob/master/mlir/docs/DeclarativeRewrites.md#building-operations
//
// Currently, we use the latter way of writing rewrite rules. There are two
// reasons for this decision:
// - To insert custom builders for operations, it is better to change the script
// gen_doc.py to generate all possibles custom builders for a large class of
// operations. At the time of this patch created, the gen_doc.py was changing,
// so we decided to write manually to reduce conflicts.
// - In declarative rewriting, we should deal with optional attributes. E.g. for
// to handle optional attributes, but I haven't tried it yet.
//
// Once we have done the above issues, we will switch to use the declarative
// manner.
//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
//===----------------------------------------------------------------------===//
struct ReduceL1OpPattern : public RewritePattern {
ReduceL1OpPattern(MLIRContext *context)
: RewritePattern(ONNXReduceL1Op::getOperationName(),
{ONNXAbsOp::getOperationName(),
ONNXReduceSumOp::getOperationName()},
1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto opInput = op->getOperands()[0]; // %X
auto opResults = op->getResults();
auto opAttrs = op->getAttrs();
// Rewrite
ONNXAbsOp absOp;
{
auto elementType = opInput.getType().cast<TensorType>().getElementType();
absOp = rewriter.create<ONNXAbsOp>(
loc, UnrankedTensorType::get(elementType), opInput);
}
ONNXReduceSumOp sumOp;
{
SmallVector<Type, 4> types;
for (auto v : opResults) {
types.emplace_back(v.getType());
}
SmallVector<Value, 1> values;
values.emplace_back(absOp.getResult());
SmallVector<NamedAttribute, 4> attrs;
for (auto attr : opAttrs) {
attrs.emplace_back(attr);
}
sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
}
rewriter.replaceOp(op, sumOp.getResult());
return matchSuccess();
};
};
//===----------------------------------------------------------------------===//
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
//===----------------------------------------------------------------------===//
struct ReduceL2OpPattern : public RewritePattern {
ReduceL2OpPattern(MLIRContext *context)
: RewritePattern(ONNXReduceL2Op::getOperationName(),
{ONNXSqrtOp::getOperationName(),
ONNXReduceSumSquareOp::getOperationName()},
1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto opInput = op->getOperands()[0]; // %X
auto opResults = op->getResults();
auto opAttrs = op->getAttrs();
// Rewrite
ONNXReduceSumSquareOp sumSquareOp;
{
auto elementType = opInput.getType().cast<TensorType>().getElementType();
sumSquareOp = rewriter.create<ONNXReduceSumSquareOp>(
loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
}
ONNXSqrtOp sqrtOp;
{
SmallVector<Type, 4> types;
for (auto v : opResults) {
types.emplace_back(v.getType());
}
sqrtOp = rewriter.create<ONNXSqrtOp>(loc, types, sumSquareOp.getResult());
}
rewriter.replaceOp(op, sqrtOp.getResult());
return matchSuccess();
};
};
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
//===----------------------------------------------------------------------===//
struct ReduceLogSumOpPattern : public RewritePattern {
ReduceLogSumOpPattern(MLIRContext *context)
: RewritePattern(ONNXReduceLogSumOp::getOperationName(),
{ONNXReduceSumOp::getOperationName(),
ONNXLogOp::getOperationName()},
1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto opInput = op->getOperands()[0]; // %X
auto opResults = op->getResults();
auto opAttrs = op->getAttrs();
// Rewrite
ONNXReduceSumOp sumOp;
{
auto elementType = opInput.getType().cast<TensorType>().getElementType();
sumOp = rewriter.create<ONNXReduceSumOp>(
loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
}
ONNXLogOp logOp;
{
SmallVector<Type, 4> types;
for (auto v : opResults) {
types.emplace_back(v.getType());
}
logOp = rewriter.create<ONNXLogOp>(loc, types, sumOp.getResult());
}
rewriter.replaceOp(op, logOp.getResult());
return matchSuccess();
};
};
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
//===----------------------------------------------------------------------===//
struct ReduceLogSumExpOpPattern : public RewritePattern {
ReduceLogSumExpOpPattern(MLIRContext *context)
: RewritePattern(ONNXReduceLogSumExpOp::getOperationName(),
{ONNXExpOp::getOperationName(),
ONNXReduceLogSumOp::getOperationName()},
1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto opInput = op->getOperands()[0]; // %X
auto opResults = op->getResults();
auto opAttrs = op->getAttrs();
// Rewrite
ONNXExpOp expOp;
{
auto elementType = opInput.getType().cast<TensorType>().getElementType();
expOp = rewriter.create<ONNXExpOp>(
loc, UnrankedTensorType::get(elementType), opInput);
}
ONNXReduceLogSumOp logSumOp;
{
SmallVector<Type, 4> types;
for (auto v : opResults) {
types.emplace_back(v.getType());
}
SmallVector<Value, 1> values;
values.emplace_back(expOp.getResult());
SmallVector<NamedAttribute, 4> attrs;
for (auto attr : opAttrs) {
attrs.emplace_back(attr);
}
logSumOp = rewriter.create<ONNXReduceLogSumOp>(loc, types, values, attrs);
}
rewriter.replaceOp(op, logSumOp.getResult());
return matchSuccess();
};
};
//===----------------------------------------------------------------------===//
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
//===----------------------------------------------------------------------===//
struct ReduceSumSquareOpPattern : public RewritePattern {
ReduceSumSquareOpPattern(MLIRContext *context)
: RewritePattern(ONNXReduceSumSquareOp::getOperationName(),
{ONNXMulOp::getOperationName(),
ONNXReduceSumOp::getOperationName()},
1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto opInput = op->getOperands()[0]; // %X
auto opResults = op->getResults();
auto opAttrs = op->getAttrs();
// Rewrite
ONNXMulOp mulOp;
{
auto elementType = opInput.getType().cast<TensorType>().getElementType();
mulOp = rewriter.create<ONNXMulOp>(
loc, UnrankedTensorType::get(elementType), opInput, opInput);
}
ONNXReduceSumOp sumOp;
{
SmallVector<Type, 4> types;
for (auto v : opResults) {
types.emplace_back(v.getType());
}
SmallVector<Value, 1> values;
values.emplace_back(mulOp.getResult());
SmallVector<NamedAttribute, 4> attrs;
for (auto attr : opAttrs) {
attrs.emplace_back(attr);
}
sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
}
rewriter.replaceOp(op, sumOp.getResult());
return matchSuccess();
};
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Rewrite: // Rewrite:

57
src/pass/onnx_rewrite.td Normal file
View File

@ -0,0 +1,57 @@
//===----------------------------------------------------------------------===//
//=- onnx_rewrite.td - Pattern Match Rewriting for ONNX -*- tablegen -*----===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Defines language-specific pattern match optimizations for ONNX using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
#ifndef ONNX_REWRITE
#define ONNX_REWRITE
#ifndef OP_BASE
include "dialect/onnx/onnx.td"
#endif // OP_BASE
/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
/// dag sourcePattern, list<dag> resultPatterns,
/// list<dag> additionalConstraints = [],
/// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
//===----------------------------------------------------------------------===//
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
//===----------------------------------------------------------------------===//
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
//===----------------------------------------------------------------------===//
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
//===----------------------------------------------------------------------===//
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
//===----------------------------------------------------------------------===//
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
//===----------------------------------------------------------------------===//
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
//===----------------------------------------------------------------------===//
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
#endif // ONNX_REWRITE