Add result type inference to op definition (#87)
* Add result type inference to op definition * Edit MLIR tests * Fix result type for Mul * Format comments * Return UnrankedTensorType as result type * Just for testing -split-input-file * Undo: Just for testing -split-input-file * Extract a function, get_operand_ins, that gets operand types; rewrite gen_attr_ins function * Generate custom builders * Call existing build methods * Add comments * Minor changes * Generate build methods with attributes * Add support of variadic type * Do not generate custom build methods for ops having only attributes Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
aea6479ad3
commit
479dd5e35a
184
doc/gen_doc.py
184
doc/gen_doc.py
|
@ -54,6 +54,15 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
|||
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
|
||||
'ReduceLogSumExp', 'ReduceSumSquare']
|
||||
|
||||
#add an Op in this list if the Op needs result type deduction which is required
|
||||
#when writing declarative rewriting rules. Deduced type is always
|
||||
#an UnrankedTensorType whose element type is the same as the first operand's
|
||||
#element type.
|
||||
#currenlty, there are only two build methods generated:
|
||||
# - one with operands and attributes having a separate parameter, and
|
||||
# - one with operands and attributes having aggregated parameters.
|
||||
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare']
|
||||
|
||||
manual_code_in_op_def = dict([
|
||||
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
||||
' static StringRef getPermAttrName() { return "perm"; }\n'+
|
||||
|
@ -345,38 +354,23 @@ def gen_schema(schema) :
|
|||
#input
|
||||
s+= '\n'+line_indent+'let arguments = (ins '
|
||||
isfirst = True
|
||||
if schema.inputs:
|
||||
isfirst = False
|
||||
for input in schema.inputs:
|
||||
if input != schema.inputs[0] :
|
||||
s+= ',\n '
|
||||
etypes=collect_types(schema, input)
|
||||
# add operands
|
||||
operand_ins = get_operand_ins(schema)
|
||||
for operand_type, operand_name in operand_ins:
|
||||
if not isfirst:
|
||||
s+= ',\n '
|
||||
else:
|
||||
isfirst = False
|
||||
s+=operand_type+':$'+operand_name
|
||||
|
||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||
#TODO: handle optional
|
||||
print("warning: optional input for"+schema.name+' '+input.name)
|
||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
s+= 'Variadic<'
|
||||
else:
|
||||
#TODO handle (variadic, heterogeneous)"
|
||||
print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name)
|
||||
if etypes == '':
|
||||
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
|
||||
else:
|
||||
s+= 'TensorOf<['+etypes+']>'
|
||||
|
||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||
#TODO: handle optional
|
||||
t=''
|
||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
s+= '>'
|
||||
else:
|
||||
#TODO handle (variadic, heterogeneous)"
|
||||
t=''
|
||||
s+=':$'+input.name
|
||||
s += gen_attr_ins(schema, isfirst)
|
||||
# add attributes
|
||||
attr_ins = get_attr_ins(schema)
|
||||
for attr_type, attr_name in attr_ins:
|
||||
if not isfirst:
|
||||
s += ',\n '
|
||||
else :
|
||||
isfirst = False
|
||||
s += attr_type+':$'+attr_name
|
||||
s+= ');'
|
||||
|
||||
#output
|
||||
|
@ -395,6 +389,71 @@ def gen_schema(schema) :
|
|||
s+= ');\n'
|
||||
|
||||
#s+= 'let hasCanonicalizer = 1;'
|
||||
|
||||
#TODO: any better way to do this.
|
||||
def get_attr_type_for_builder(attr_type) :
|
||||
if 'I64Attr' in attr_type :
|
||||
mytype = 'IntegerAttr'
|
||||
elif 'F32Attr' in attr_type :
|
||||
mytype = 'FloatAttr'
|
||||
elif 'I64ArrayAttr' in attr_type or 'F32ArrayAttr' in attr_type:
|
||||
mytype = 'ArrayAttr'
|
||||
elif 'StrAttr' in attr_type :
|
||||
mytype = 'StringAttr'
|
||||
elif 'strings' in attr_type :
|
||||
mytype = 'ArrayAttr'
|
||||
else :
|
||||
mytype ='Attribute'
|
||||
return mytype
|
||||
|
||||
def get_op_type_for_builder(op_type):
|
||||
if op_type.startswith('Variadic'):
|
||||
mytype = 'ValueRange'
|
||||
else:
|
||||
mytype = 'Value'
|
||||
return mytype
|
||||
|
||||
# add custom builders
|
||||
# use element type of the first operand to construct an UnrankedTensorType for the output.
|
||||
if schema.name in custom_builder_ops_list:
|
||||
if len(operand_ins) == 0:
|
||||
print("warning: not generate custom build methods for " + schema.name + " since it does not have operands.")
|
||||
else:
|
||||
if get_op_type_for_builder(operand_ins[0][0]) == 'ValueRange':
|
||||
first_operand = operand_ins[0][1]+'[0]'
|
||||
else:
|
||||
first_operand = operand_ins[0][1]
|
||||
|
||||
s += line_indent+'let builders = [\n'
|
||||
|
||||
# custom builders with operands and attributes having a seperate parameter.
|
||||
# E.g. OpBuilder<"Builder *builder, OperationState &state, Value X, Value, Y, Attribute A", [{}]>
|
||||
s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state'
|
||||
for arg_type, arg_name in operand_ins:
|
||||
s += ', '+get_op_type_for_builder(arg_type)+' '+arg_name
|
||||
for attr_type, attr_name in attr_ins:
|
||||
s += ', '+get_attr_type_for_builder(attr_type)+' '+attr_name
|
||||
s += '", [{\n'
|
||||
s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast<TensorType>().getElementType();\n'
|
||||
s += line_indent*3+'build(builder, state, UnrankedTensorType::get(elementType)'
|
||||
for _, arg_name in operand_ins:
|
||||
s += ', '+arg_name
|
||||
for _, attr_name in attr_ins:
|
||||
s += ', '+attr_name
|
||||
s += ');\n'
|
||||
s += line_indent*2+'}]>,\n'
|
||||
|
||||
# custom builders with all operands and attributes having aggregate parameters.
|
||||
# E.g. OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{}]>'
|
||||
s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
|
||||
s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast<TensorType>().getElementType();\n'
|
||||
s += line_indent*3+'std::vector<mlir::Type> outputTypes;\n'
|
||||
s += line_indent*3+'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n'
|
||||
s += line_indent*3+'build(builder, state, outputTypes, operands, attributes);\n'
|
||||
s += line_indent*2+'}]>'
|
||||
|
||||
s += '\n'+line_indent+'];\n'
|
||||
|
||||
#add special code
|
||||
if schema.name in manual_code_in_op_def :
|
||||
s += manual_code_in_op_def[schema.name]
|
||||
|
@ -447,7 +506,41 @@ def gen_code(schema,fefile) :
|
|||
else:
|
||||
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
|
||||
|
||||
def gen_attr_ins(schema, isfirst) :
|
||||
def get_operand_ins(schema):
|
||||
operand_type_and_name_list = [] # [(optype, opname)]
|
||||
if schema.inputs:
|
||||
for input in schema.inputs:
|
||||
optype = ""
|
||||
|
||||
etypes=collect_types(schema, input)
|
||||
|
||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||
#TODO : handle optional
|
||||
print("warning: optional input for"+schema.name+' '+input.name)
|
||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
optype += 'Variadic<'
|
||||
else:
|
||||
#TODO handle(variadic, heterogeneous) "
|
||||
print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name)
|
||||
if etypes == '':
|
||||
optype += 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
|
||||
else:
|
||||
optype += 'TensorOf<['+etypes+']>'
|
||||
|
||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||
#TODO : handle optional
|
||||
t=''
|
||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
optype += '>'
|
||||
else:
|
||||
#TODO handle(variadic, heterogeneous) "
|
||||
t=''
|
||||
operand_type_and_name_list.append((optype, input.name))
|
||||
return operand_type_and_name_list
|
||||
|
||||
def get_attr_ins(schema) :
|
||||
|
||||
def get_attr_type_basic(attr_type) :
|
||||
if attr_type == 'int' :
|
||||
|
@ -479,24 +572,22 @@ def gen_attr_ins(schema, isfirst) :
|
|||
mytype += ', "'+attr_default+'">'
|
||||
return mytype
|
||||
|
||||
attr_type_and_name_list = [] # :: [(attrtype, attrname)]
|
||||
attr_line = ''
|
||||
if schema.attributes:
|
||||
for _, attr in sorted(schema.attributes.items()):
|
||||
#attr_line = line_indent+line_indent+line_indent+line_indent
|
||||
if not isfirst:
|
||||
attr_line += ',\n '
|
||||
else :
|
||||
isfirst = False
|
||||
|
||||
found = False
|
||||
attr_type = ""
|
||||
if schema.name+' '+attr.name in special_attr_defaults:
|
||||
(attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name]
|
||||
attr_line += get_attr_type_with_default(attr_type_str, attr_default_str)
|
||||
attr_line += ':$'+attr.name
|
||||
attr_type = get_attr_type_with_default(attr_type_str, attr_default_str)
|
||||
found = True
|
||||
elif attr.required:
|
||||
s = Text(attr.type)
|
||||
attr_type_str = s[s.rfind('.') + 1:].lower()
|
||||
attr_line += get_attr_type_basic(attr_type_str)
|
||||
attr_line += ':$'+attr.name
|
||||
attr_type = get_attr_type_basic(attr_type_str)
|
||||
found = True
|
||||
|
||||
# option holds either required or default value
|
||||
elif attr.default_value.name:
|
||||
|
@ -527,14 +618,15 @@ def gen_attr_ins(schema, isfirst) :
|
|||
else:
|
||||
default_value = format_value(default_value)
|
||||
attr_option_str = default_value
|
||||
attr_line += get_attr_type_with_default(attr_type_str, attr_option_str)
|
||||
attr_line += ':$'+attr.name
|
||||
attr_type = get_attr_type_with_default(attr_type_str, attr_option_str)
|
||||
found = True
|
||||
else:
|
||||
s = Text(attr.type)
|
||||
attr_type_str = s[s.rfind('.') + 1:].lower()
|
||||
attr_line += get_attr_type_optional(attr_type_str)
|
||||
attr_line += ':$'+attr.name
|
||||
return attr_line
|
||||
attr_type = get_attr_type_optional(attr_type_str)
|
||||
if found:
|
||||
attr_type_and_name_list.append((attr_type, attr.name))
|
||||
return attr_type_and_name_list
|
||||
|
||||
def main(args): # type: (Type[Args]) -> None
|
||||
with io.open(args.changelog, 'w', newline='') as fout:
|
||||
|
|
|
@ -36,6 +36,11 @@ onnf_tablegen(onnx_combine.inc -gen-rewriters)
|
|||
add_public_tablegen_target(gen_onnx_combine)
|
||||
add_dependencies(compiler gen_onnx_combine)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS pass/onnx_rewrite.td)
|
||||
onnf_tablegen(onnx_rewrite.inc -gen-rewriters)
|
||||
add_public_tablegen_target(gen_onnx_rewrite)
|
||||
add_dependencies(compiler gen_onnx_rewrite)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
|
|
|
@ -14,6 +14,18 @@ def ONNXAbsOp:ONNX_Op<"Abs",
|
|||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &state, Value X", [{
|
||||
auto elementType = X.getType().cast<TensorType>().getElementType();
|
||||
build(builder, state, UnrankedTensorType::get(elementType), X);
|
||||
}]>,
|
||||
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||
build(builder, state, outputTypes, operands, attributes);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def ONNXAcosOp:ONNX_Op<"Acos",
|
||||
|
@ -649,6 +661,18 @@ def ONNXExpOp:ONNX_Op<"Exp",
|
|||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &state, Value input", [{
|
||||
auto elementType = input.getType().cast<TensorType>().getElementType();
|
||||
build(builder, state, UnrankedTensorType::get(elementType), input);
|
||||
}]>,
|
||||
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||
build(builder, state, outputTypes, operands, attributes);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def ONNXExpandOp:ONNX_Op<"Expand",
|
||||
|
@ -1763,6 +1787,18 @@ def ONNXMulOp:ONNX_Op<"Mul",
|
|||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C);
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &state, Value A, Value B", [{
|
||||
auto elementType = A.getType().cast<TensorType>().getElementType();
|
||||
build(builder, state, UnrankedTensorType::get(elementType), A, B);
|
||||
}]>,
|
||||
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||
build(builder, state, outputTypes, operands, attributes);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def ONNXMultinomialOp:ONNX_Op<"Multinomial",
|
||||
|
@ -2431,6 +2467,18 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
|
|||
OptionalAttr<I64ArrayAttr>:$axes,
|
||||
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced);
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{
|
||||
auto elementType = data.getType().cast<TensorType>().getElementType();
|
||||
build(builder, state, UnrankedTensorType::get(elementType), data, axes, keepdims);
|
||||
}]>,
|
||||
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||
build(builder, state, outputTypes, operands, attributes);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
||||
|
@ -2449,6 +2497,18 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
|||
OptionalAttr<I64ArrayAttr>:$axes,
|
||||
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced);
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{
|
||||
auto elementType = data.getType().cast<TensorType>().getElementType();
|
||||
build(builder, state, UnrankedTensorType::get(elementType), data, axes, keepdims);
|
||||
}]>,
|
||||
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||
build(builder, state, outputTypes, operands, attributes);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def ONNXReluOp:ONNX_Op<"Relu",
|
||||
|
|
|
@ -17,252 +17,8 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
// There are two ways to write rewrite rules:
|
||||
// - Declarative manner: specify rewrite rules in a TableGen record, and
|
||||
// - Manual Manner: subclass the mlir::RewritePattern.
|
||||
//
|
||||
// We prefer to use the former way as much as possible. However, there is a
|
||||
// limitation about operation definition specification (ODS) in TableGen that
|
||||
// requires us to write custom builders, that is
|
||||
// "all ODS-generated `build()` methods require specifying the result type(s),
|
||||
// unless the op has known traits like `SameOperandsAndResultType` that we can
|
||||
// use to auto-generate a `build()` method with result type deduction".
|
||||
//
|
||||
// More information about the limitation can be found here:
|
||||
// https://github.com/llvm/llvm-project/blob/master/mlir/docs/DeclarativeRewrites.md#building-operations
|
||||
//
|
||||
// Currently, we use the latter way of writing rewrite rules. There are two
|
||||
// reasons for this decision:
|
||||
// - To insert custom builders for operations, it is better to change the script
|
||||
// gen_doc.py to generate all possibles custom builders for a large class of
|
||||
// operations. At the time of this patch created, the gen_doc.py was changing,
|
||||
// so we decided to write manually to reduce conflicts.
|
||||
// - In declarative rewriting, we should deal with optional attributes. E.g. for
|
||||
// to handle optional attributes, but I haven't tried it yet.
|
||||
//
|
||||
// Once we have done the above issues, we will switch to use the declarative
|
||||
// manner.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct ReduceL1OpPattern : public RewritePattern {
|
||||
ReduceL1OpPattern(MLIRContext *context)
|
||||
: RewritePattern(ONNXReduceL1Op::getOperationName(),
|
||||
{ONNXAbsOp::getOperationName(),
|
||||
ONNXReduceSumOp::getOperationName()},
|
||||
1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto opInput = op->getOperands()[0]; // %X
|
||||
auto opResults = op->getResults();
|
||||
auto opAttrs = op->getAttrs();
|
||||
|
||||
// Rewrite
|
||||
ONNXAbsOp absOp;
|
||||
{
|
||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||
absOp = rewriter.create<ONNXAbsOp>(
|
||||
loc, UnrankedTensorType::get(elementType), opInput);
|
||||
}
|
||||
|
||||
ONNXReduceSumOp sumOp;
|
||||
{
|
||||
SmallVector<Type, 4> types;
|
||||
for (auto v : opResults) {
|
||||
types.emplace_back(v.getType());
|
||||
}
|
||||
|
||||
SmallVector<Value, 1> values;
|
||||
values.emplace_back(absOp.getResult());
|
||||
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
for (auto attr : opAttrs) {
|
||||
attrs.emplace_back(attr);
|
||||
}
|
||||
|
||||
sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, sumOp.getResult());
|
||||
return matchSuccess();
|
||||
};
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct ReduceL2OpPattern : public RewritePattern {
|
||||
ReduceL2OpPattern(MLIRContext *context)
|
||||
: RewritePattern(ONNXReduceL2Op::getOperationName(),
|
||||
{ONNXSqrtOp::getOperationName(),
|
||||
ONNXReduceSumSquareOp::getOperationName()},
|
||||
1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto opInput = op->getOperands()[0]; // %X
|
||||
auto opResults = op->getResults();
|
||||
auto opAttrs = op->getAttrs();
|
||||
|
||||
// Rewrite
|
||||
ONNXReduceSumSquareOp sumSquareOp;
|
||||
{
|
||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||
sumSquareOp = rewriter.create<ONNXReduceSumSquareOp>(
|
||||
loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
|
||||
}
|
||||
|
||||
ONNXSqrtOp sqrtOp;
|
||||
{
|
||||
SmallVector<Type, 4> types;
|
||||
for (auto v : opResults) {
|
||||
types.emplace_back(v.getType());
|
||||
}
|
||||
sqrtOp = rewriter.create<ONNXSqrtOp>(loc, types, sumSquareOp.getResult());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, sqrtOp.getResult());
|
||||
return matchSuccess();
|
||||
};
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct ReduceLogSumOpPattern : public RewritePattern {
|
||||
ReduceLogSumOpPattern(MLIRContext *context)
|
||||
: RewritePattern(ONNXReduceLogSumOp::getOperationName(),
|
||||
{ONNXReduceSumOp::getOperationName(),
|
||||
ONNXLogOp::getOperationName()},
|
||||
1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto opInput = op->getOperands()[0]; // %X
|
||||
auto opResults = op->getResults();
|
||||
auto opAttrs = op->getAttrs();
|
||||
|
||||
// Rewrite
|
||||
ONNXReduceSumOp sumOp;
|
||||
{
|
||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||
sumOp = rewriter.create<ONNXReduceSumOp>(
|
||||
loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
|
||||
}
|
||||
|
||||
ONNXLogOp logOp;
|
||||
{
|
||||
SmallVector<Type, 4> types;
|
||||
for (auto v : opResults) {
|
||||
types.emplace_back(v.getType());
|
||||
}
|
||||
logOp = rewriter.create<ONNXLogOp>(loc, types, sumOp.getResult());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, logOp.getResult());
|
||||
return matchSuccess();
|
||||
};
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct ReduceLogSumExpOpPattern : public RewritePattern {
|
||||
ReduceLogSumExpOpPattern(MLIRContext *context)
|
||||
: RewritePattern(ONNXReduceLogSumExpOp::getOperationName(),
|
||||
{ONNXExpOp::getOperationName(),
|
||||
ONNXReduceLogSumOp::getOperationName()},
|
||||
1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto opInput = op->getOperands()[0]; // %X
|
||||
auto opResults = op->getResults();
|
||||
auto opAttrs = op->getAttrs();
|
||||
|
||||
// Rewrite
|
||||
ONNXExpOp expOp;
|
||||
{
|
||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||
expOp = rewriter.create<ONNXExpOp>(
|
||||
loc, UnrankedTensorType::get(elementType), opInput);
|
||||
}
|
||||
|
||||
ONNXReduceLogSumOp logSumOp;
|
||||
{
|
||||
SmallVector<Type, 4> types;
|
||||
for (auto v : opResults) {
|
||||
types.emplace_back(v.getType());
|
||||
}
|
||||
|
||||
SmallVector<Value, 1> values;
|
||||
values.emplace_back(expOp.getResult());
|
||||
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
for (auto attr : opAttrs) {
|
||||
attrs.emplace_back(attr);
|
||||
}
|
||||
logSumOp = rewriter.create<ONNXReduceLogSumOp>(loc, types, values, attrs);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, logSumOp.getResult());
|
||||
return matchSuccess();
|
||||
};
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct ReduceSumSquareOpPattern : public RewritePattern {
|
||||
ReduceSumSquareOpPattern(MLIRContext *context)
|
||||
: RewritePattern(ONNXReduceSumSquareOp::getOperationName(),
|
||||
{ONNXMulOp::getOperationName(),
|
||||
ONNXReduceSumOp::getOperationName()},
|
||||
1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto opInput = op->getOperands()[0]; // %X
|
||||
auto opResults = op->getResults();
|
||||
auto opAttrs = op->getAttrs();
|
||||
|
||||
// Rewrite
|
||||
ONNXMulOp mulOp;
|
||||
{
|
||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||
mulOp = rewriter.create<ONNXMulOp>(
|
||||
loc, UnrankedTensorType::get(elementType), opInput, opInput);
|
||||
}
|
||||
|
||||
ONNXReduceSumOp sumOp;
|
||||
{
|
||||
SmallVector<Type, 4> types;
|
||||
for (auto v : opResults) {
|
||||
types.emplace_back(v.getType());
|
||||
}
|
||||
|
||||
SmallVector<Value, 1> values;
|
||||
values.emplace_back(mulOp.getResult());
|
||||
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
for (auto attr : opAttrs) {
|
||||
attrs.emplace_back(attr);
|
||||
}
|
||||
sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, sumOp.getResult());
|
||||
return matchSuccess();
|
||||
};
|
||||
};
|
||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "src/onnx_rewrite.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Rewrite:
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//=- onnx_rewrite.td - Pattern Match Rewriting for ONNX -*- tablegen -*----===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines language-specific pattern match optimizations for ONNX using
|
||||
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||
//
|
||||
|
||||
#ifndef ONNX_REWRITE
|
||||
#define ONNX_REWRITE
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "dialect/onnx/onnx.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
/// Note: The DRR definition used for defining patterns is shown below:
|
||||
///
|
||||
/// class Pattern<
|
||||
/// dag sourcePattern, list<dag> resultPatterns,
|
||||
/// list<dag> additionalConstraints = [],
|
||||
/// dag benefitsAdded = (addBenefit 0)
|
||||
/// >;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
|
||||
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
|
||||
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
|
||||
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
|
||||
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
|
||||
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
|
||||
|
||||
#endif // ONNX_REWRITE
|
Loading…
Reference in New Issue