Add result type inference to op definition (#87)
* Add result type inference to op definition * Edit MLIR tests * Fix result type for Mul * Format comments * Return UnrankedTensorType as result type * Just for testing -split-input-file * Undo: Just for testing -split-input-file * Extract a function, get_operand_ins, that gets operand types; rewrite gen_attr_ins function * Generate custom builders * Call existing build methods * Add comments * Minor changes * Generate build methods with attributes * Add support of variadic type * Do not generate custom build methods for ops having only attributes Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
aea6479ad3
commit
479dd5e35a
184
doc/gen_doc.py
184
doc/gen_doc.py
|
@ -54,6 +54,15 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
||||||
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
|
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
|
||||||
'ReduceLogSumExp', 'ReduceSumSquare']
|
'ReduceLogSumExp', 'ReduceSumSquare']
|
||||||
|
|
||||||
|
#add an Op in this list if the Op needs result type deduction which is required
|
||||||
|
#when writing declarative rewriting rules. Deduced type is always
|
||||||
|
#an UnrankedTensorType whose element type is the same as the first operand's
|
||||||
|
#element type.
|
||||||
|
#currenlty, there are only two build methods generated:
|
||||||
|
# - one with operands and attributes having a separate parameter, and
|
||||||
|
# - one with operands and attributes having aggregated parameters.
|
||||||
|
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare']
|
||||||
|
|
||||||
manual_code_in_op_def = dict([
|
manual_code_in_op_def = dict([
|
||||||
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
||||||
' static StringRef getPermAttrName() { return "perm"; }\n'+
|
' static StringRef getPermAttrName() { return "perm"; }\n'+
|
||||||
|
@ -345,38 +354,23 @@ def gen_schema(schema) :
|
||||||
#input
|
#input
|
||||||
s+= '\n'+line_indent+'let arguments = (ins '
|
s+= '\n'+line_indent+'let arguments = (ins '
|
||||||
isfirst = True
|
isfirst = True
|
||||||
if schema.inputs:
|
# add operands
|
||||||
isfirst = False
|
operand_ins = get_operand_ins(schema)
|
||||||
for input in schema.inputs:
|
for operand_type, operand_name in operand_ins:
|
||||||
if input != schema.inputs[0] :
|
if not isfirst:
|
||||||
s+= ',\n '
|
s+= ',\n '
|
||||||
etypes=collect_types(schema, input)
|
else:
|
||||||
|
isfirst = False
|
||||||
|
s+=operand_type+':$'+operand_name
|
||||||
|
|
||||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
# add attributes
|
||||||
#TODO: handle optional
|
attr_ins = get_attr_ins(schema)
|
||||||
print("warning: optional input for"+schema.name+' '+input.name)
|
for attr_type, attr_name in attr_ins:
|
||||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
if not isfirst:
|
||||||
if input.isHomogeneous:
|
s += ',\n '
|
||||||
s+= 'Variadic<'
|
else :
|
||||||
else:
|
isfirst = False
|
||||||
#TODO handle (variadic, heterogeneous)"
|
s += attr_type+':$'+attr_name
|
||||||
print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name)
|
|
||||||
if etypes == '':
|
|
||||||
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
|
|
||||||
else:
|
|
||||||
s+= 'TensorOf<['+etypes+']>'
|
|
||||||
|
|
||||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
|
||||||
#TODO: handle optional
|
|
||||||
t=''
|
|
||||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
|
||||||
if input.isHomogeneous:
|
|
||||||
s+= '>'
|
|
||||||
else:
|
|
||||||
#TODO handle (variadic, heterogeneous)"
|
|
||||||
t=''
|
|
||||||
s+=':$'+input.name
|
|
||||||
s += gen_attr_ins(schema, isfirst)
|
|
||||||
s+= ');'
|
s+= ');'
|
||||||
|
|
||||||
#output
|
#output
|
||||||
|
@ -395,6 +389,71 @@ def gen_schema(schema) :
|
||||||
s+= ');\n'
|
s+= ');\n'
|
||||||
|
|
||||||
#s+= 'let hasCanonicalizer = 1;'
|
#s+= 'let hasCanonicalizer = 1;'
|
||||||
|
|
||||||
|
#TODO: any better way to do this.
|
||||||
|
def get_attr_type_for_builder(attr_type) :
|
||||||
|
if 'I64Attr' in attr_type :
|
||||||
|
mytype = 'IntegerAttr'
|
||||||
|
elif 'F32Attr' in attr_type :
|
||||||
|
mytype = 'FloatAttr'
|
||||||
|
elif 'I64ArrayAttr' in attr_type or 'F32ArrayAttr' in attr_type:
|
||||||
|
mytype = 'ArrayAttr'
|
||||||
|
elif 'StrAttr' in attr_type :
|
||||||
|
mytype = 'StringAttr'
|
||||||
|
elif 'strings' in attr_type :
|
||||||
|
mytype = 'ArrayAttr'
|
||||||
|
else :
|
||||||
|
mytype ='Attribute'
|
||||||
|
return mytype
|
||||||
|
|
||||||
|
def get_op_type_for_builder(op_type):
|
||||||
|
if op_type.startswith('Variadic'):
|
||||||
|
mytype = 'ValueRange'
|
||||||
|
else:
|
||||||
|
mytype = 'Value'
|
||||||
|
return mytype
|
||||||
|
|
||||||
|
# add custom builders
|
||||||
|
# use element type of the first operand to construct an UnrankedTensorType for the output.
|
||||||
|
if schema.name in custom_builder_ops_list:
|
||||||
|
if len(operand_ins) == 0:
|
||||||
|
print("warning: not generate custom build methods for " + schema.name + " since it does not have operands.")
|
||||||
|
else:
|
||||||
|
if get_op_type_for_builder(operand_ins[0][0]) == 'ValueRange':
|
||||||
|
first_operand = operand_ins[0][1]+'[0]'
|
||||||
|
else:
|
||||||
|
first_operand = operand_ins[0][1]
|
||||||
|
|
||||||
|
s += line_indent+'let builders = [\n'
|
||||||
|
|
||||||
|
# custom builders with operands and attributes having a seperate parameter.
|
||||||
|
# E.g. OpBuilder<"Builder *builder, OperationState &state, Value X, Value, Y, Attribute A", [{}]>
|
||||||
|
s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state'
|
||||||
|
for arg_type, arg_name in operand_ins:
|
||||||
|
s += ', '+get_op_type_for_builder(arg_type)+' '+arg_name
|
||||||
|
for attr_type, attr_name in attr_ins:
|
||||||
|
s += ', '+get_attr_type_for_builder(attr_type)+' '+attr_name
|
||||||
|
s += '", [{\n'
|
||||||
|
s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast<TensorType>().getElementType();\n'
|
||||||
|
s += line_indent*3+'build(builder, state, UnrankedTensorType::get(elementType)'
|
||||||
|
for _, arg_name in operand_ins:
|
||||||
|
s += ', '+arg_name
|
||||||
|
for _, attr_name in attr_ins:
|
||||||
|
s += ', '+attr_name
|
||||||
|
s += ');\n'
|
||||||
|
s += line_indent*2+'}]>,\n'
|
||||||
|
|
||||||
|
# custom builders with all operands and attributes having aggregate parameters.
|
||||||
|
# E.g. OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{}]>'
|
||||||
|
s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
|
||||||
|
s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast<TensorType>().getElementType();\n'
|
||||||
|
s += line_indent*3+'std::vector<mlir::Type> outputTypes;\n'
|
||||||
|
s += line_indent*3+'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n'
|
||||||
|
s += line_indent*3+'build(builder, state, outputTypes, operands, attributes);\n'
|
||||||
|
s += line_indent*2+'}]>'
|
||||||
|
|
||||||
|
s += '\n'+line_indent+'];\n'
|
||||||
|
|
||||||
#add special code
|
#add special code
|
||||||
if schema.name in manual_code_in_op_def :
|
if schema.name in manual_code_in_op_def :
|
||||||
s += manual_code_in_op_def[schema.name]
|
s += manual_code_in_op_def[schema.name]
|
||||||
|
@ -447,7 +506,41 @@ def gen_code(schema,fefile) :
|
||||||
else:
|
else:
|
||||||
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
|
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
|
||||||
|
|
||||||
def gen_attr_ins(schema, isfirst) :
|
def get_operand_ins(schema):
|
||||||
|
operand_type_and_name_list = [] # [(optype, opname)]
|
||||||
|
if schema.inputs:
|
||||||
|
for input in schema.inputs:
|
||||||
|
optype = ""
|
||||||
|
|
||||||
|
etypes=collect_types(schema, input)
|
||||||
|
|
||||||
|
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||||
|
#TODO : handle optional
|
||||||
|
print("warning: optional input for"+schema.name+' '+input.name)
|
||||||
|
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||||
|
if input.isHomogeneous:
|
||||||
|
optype += 'Variadic<'
|
||||||
|
else:
|
||||||
|
#TODO handle(variadic, heterogeneous) "
|
||||||
|
print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name)
|
||||||
|
if etypes == '':
|
||||||
|
optype += 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
|
||||||
|
else:
|
||||||
|
optype += 'TensorOf<['+etypes+']>'
|
||||||
|
|
||||||
|
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||||
|
#TODO : handle optional
|
||||||
|
t=''
|
||||||
|
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||||
|
if input.isHomogeneous:
|
||||||
|
optype += '>'
|
||||||
|
else:
|
||||||
|
#TODO handle(variadic, heterogeneous) "
|
||||||
|
t=''
|
||||||
|
operand_type_and_name_list.append((optype, input.name))
|
||||||
|
return operand_type_and_name_list
|
||||||
|
|
||||||
|
def get_attr_ins(schema) :
|
||||||
|
|
||||||
def get_attr_type_basic(attr_type) :
|
def get_attr_type_basic(attr_type) :
|
||||||
if attr_type == 'int' :
|
if attr_type == 'int' :
|
||||||
|
@ -479,24 +572,22 @@ def gen_attr_ins(schema, isfirst) :
|
||||||
mytype += ', "'+attr_default+'">'
|
mytype += ', "'+attr_default+'">'
|
||||||
return mytype
|
return mytype
|
||||||
|
|
||||||
|
attr_type_and_name_list = [] # :: [(attrtype, attrname)]
|
||||||
attr_line = ''
|
attr_line = ''
|
||||||
if schema.attributes:
|
if schema.attributes:
|
||||||
for _, attr in sorted(schema.attributes.items()):
|
for _, attr in sorted(schema.attributes.items()):
|
||||||
#attr_line = line_indent+line_indent+line_indent+line_indent
|
#attr_line = line_indent+line_indent+line_indent+line_indent
|
||||||
if not isfirst:
|
found = False
|
||||||
attr_line += ',\n '
|
attr_type = ""
|
||||||
else :
|
|
||||||
isfirst = False
|
|
||||||
|
|
||||||
if schema.name+' '+attr.name in special_attr_defaults:
|
if schema.name+' '+attr.name in special_attr_defaults:
|
||||||
(attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name]
|
(attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name]
|
||||||
attr_line += get_attr_type_with_default(attr_type_str, attr_default_str)
|
attr_type = get_attr_type_with_default(attr_type_str, attr_default_str)
|
||||||
attr_line += ':$'+attr.name
|
found = True
|
||||||
elif attr.required:
|
elif attr.required:
|
||||||
s = Text(attr.type)
|
s = Text(attr.type)
|
||||||
attr_type_str = s[s.rfind('.') + 1:].lower()
|
attr_type_str = s[s.rfind('.') + 1:].lower()
|
||||||
attr_line += get_attr_type_basic(attr_type_str)
|
attr_type = get_attr_type_basic(attr_type_str)
|
||||||
attr_line += ':$'+attr.name
|
found = True
|
||||||
|
|
||||||
# option holds either required or default value
|
# option holds either required or default value
|
||||||
elif attr.default_value.name:
|
elif attr.default_value.name:
|
||||||
|
@ -527,14 +618,15 @@ def gen_attr_ins(schema, isfirst) :
|
||||||
else:
|
else:
|
||||||
default_value = format_value(default_value)
|
default_value = format_value(default_value)
|
||||||
attr_option_str = default_value
|
attr_option_str = default_value
|
||||||
attr_line += get_attr_type_with_default(attr_type_str, attr_option_str)
|
attr_type = get_attr_type_with_default(attr_type_str, attr_option_str)
|
||||||
attr_line += ':$'+attr.name
|
found = True
|
||||||
else:
|
else:
|
||||||
s = Text(attr.type)
|
s = Text(attr.type)
|
||||||
attr_type_str = s[s.rfind('.') + 1:].lower()
|
attr_type_str = s[s.rfind('.') + 1:].lower()
|
||||||
attr_line += get_attr_type_optional(attr_type_str)
|
attr_type = get_attr_type_optional(attr_type_str)
|
||||||
attr_line += ':$'+attr.name
|
if found:
|
||||||
return attr_line
|
attr_type_and_name_list.append((attr_type, attr.name))
|
||||||
|
return attr_type_and_name_list
|
||||||
|
|
||||||
def main(args): # type: (Type[Args]) -> None
|
def main(args): # type: (Type[Args]) -> None
|
||||||
with io.open(args.changelog, 'w', newline='') as fout:
|
with io.open(args.changelog, 'w', newline='') as fout:
|
||||||
|
|
|
@ -36,6 +36,11 @@ onnf_tablegen(onnx_combine.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(gen_onnx_combine)
|
add_public_tablegen_target(gen_onnx_combine)
|
||||||
add_dependencies(compiler gen_onnx_combine)
|
add_dependencies(compiler gen_onnx_combine)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS pass/onnx_rewrite.td)
|
||||||
|
onnf_tablegen(onnx_rewrite.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(gen_onnx_rewrite)
|
||||||
|
add_dependencies(compiler gen_onnx_rewrite)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
|
|
|
@ -14,6 +14,18 @@ def ONNXAbsOp:ONNX_Op<"Abs",
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, Value X", [{
|
||||||
|
auto elementType = X.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), X);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXAcosOp:ONNX_Op<"Acos",
|
def ONNXAcosOp:ONNX_Op<"Acos",
|
||||||
|
@ -649,6 +661,18 @@ def ONNXExpOp:ONNX_Op<"Exp",
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, Value input", [{
|
||||||
|
auto elementType = input.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), input);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXExpandOp:ONNX_Op<"Expand",
|
def ONNXExpandOp:ONNX_Op<"Expand",
|
||||||
|
@ -1763,6 +1787,18 @@ def ONNXMulOp:ONNX_Op<"Mul",
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXMultinomialOp:ONNX_Op<"Multinomial",
|
def ONNXMultinomialOp:ONNX_Op<"Multinomial",
|
||||||
|
@ -2431,6 +2467,18 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
|
||||||
OptionalAttr<I64ArrayAttr>:$axes,
|
OptionalAttr<I64ArrayAttr>:$axes,
|
||||||
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{
|
||||||
|
auto elementType = data.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), data, axes, keepdims);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
||||||
|
@ -2449,6 +2497,18 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
||||||
OptionalAttr<I64ArrayAttr>:$axes,
|
OptionalAttr<I64ArrayAttr>:$axes,
|
||||||
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{
|
||||||
|
auto elementType = data.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), data, axes, keepdims);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXReluOp:ONNX_Op<"Relu",
|
def ONNXReluOp:ONNX_Op<"Relu",
|
||||||
|
|
|
@ -17,252 +17,8 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
// There are two ways to write rewrite rules:
|
#include "src/onnx_rewrite.inc"
|
||||||
// - Declarative manner: specify rewrite rules in a TableGen record, and
|
|
||||||
// - Manual Manner: subclass the mlir::RewritePattern.
|
|
||||||
//
|
|
||||||
// We prefer to use the former way as much as possible. However, there is a
|
|
||||||
// limitation about operation definition specification (ODS) in TableGen that
|
|
||||||
// requires us to write custom builders, that is
|
|
||||||
// "all ODS-generated `build()` methods require specifying the result type(s),
|
|
||||||
// unless the op has known traits like `SameOperandsAndResultType` that we can
|
|
||||||
// use to auto-generate a `build()` method with result type deduction".
|
|
||||||
//
|
|
||||||
// More information about the limitation can be found here:
|
|
||||||
// https://github.com/llvm/llvm-project/blob/master/mlir/docs/DeclarativeRewrites.md#building-operations
|
|
||||||
//
|
|
||||||
// Currently, we use the latter way of writing rewrite rules. There are two
|
|
||||||
// reasons for this decision:
|
|
||||||
// - To insert custom builders for operations, it is better to change the script
|
|
||||||
// gen_doc.py to generate all possibles custom builders for a large class of
|
|
||||||
// operations. At the time of this patch created, the gen_doc.py was changing,
|
|
||||||
// so we decided to write manually to reduce conflicts.
|
|
||||||
// - In declarative rewriting, we should deal with optional attributes. E.g. for
|
|
||||||
// to handle optional attributes, but I haven't tried it yet.
|
|
||||||
//
|
|
||||||
// Once we have done the above issues, we will switch to use the declarative
|
|
||||||
// manner.
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
struct ReduceL1OpPattern : public RewritePattern {
|
|
||||||
ReduceL1OpPattern(MLIRContext *context)
|
|
||||||
: RewritePattern(ONNXReduceL1Op::getOperationName(),
|
|
||||||
{ONNXAbsOp::getOperationName(),
|
|
||||||
ONNXReduceSumOp::getOperationName()},
|
|
||||||
1, context) {}
|
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto loc = op->getLoc();
|
|
||||||
auto opInput = op->getOperands()[0]; // %X
|
|
||||||
auto opResults = op->getResults();
|
|
||||||
auto opAttrs = op->getAttrs();
|
|
||||||
|
|
||||||
// Rewrite
|
|
||||||
ONNXAbsOp absOp;
|
|
||||||
{
|
|
||||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
|
||||||
absOp = rewriter.create<ONNXAbsOp>(
|
|
||||||
loc, UnrankedTensorType::get(elementType), opInput);
|
|
||||||
}
|
|
||||||
|
|
||||||
ONNXReduceSumOp sumOp;
|
|
||||||
{
|
|
||||||
SmallVector<Type, 4> types;
|
|
||||||
for (auto v : opResults) {
|
|
||||||
types.emplace_back(v.getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 1> values;
|
|
||||||
values.emplace_back(absOp.getResult());
|
|
||||||
|
|
||||||
SmallVector<NamedAttribute, 4> attrs;
|
|
||||||
for (auto attr : opAttrs) {
|
|
||||||
attrs.emplace_back(attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, sumOp.getResult());
|
|
||||||
return matchSuccess();
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
struct ReduceL2OpPattern : public RewritePattern {
|
|
||||||
ReduceL2OpPattern(MLIRContext *context)
|
|
||||||
: RewritePattern(ONNXReduceL2Op::getOperationName(),
|
|
||||||
{ONNXSqrtOp::getOperationName(),
|
|
||||||
ONNXReduceSumSquareOp::getOperationName()},
|
|
||||||
1, context) {}
|
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto loc = op->getLoc();
|
|
||||||
auto opInput = op->getOperands()[0]; // %X
|
|
||||||
auto opResults = op->getResults();
|
|
||||||
auto opAttrs = op->getAttrs();
|
|
||||||
|
|
||||||
// Rewrite
|
|
||||||
ONNXReduceSumSquareOp sumSquareOp;
|
|
||||||
{
|
|
||||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
|
||||||
sumSquareOp = rewriter.create<ONNXReduceSumSquareOp>(
|
|
||||||
loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
|
|
||||||
}
|
|
||||||
|
|
||||||
ONNXSqrtOp sqrtOp;
|
|
||||||
{
|
|
||||||
SmallVector<Type, 4> types;
|
|
||||||
for (auto v : opResults) {
|
|
||||||
types.emplace_back(v.getType());
|
|
||||||
}
|
|
||||||
sqrtOp = rewriter.create<ONNXSqrtOp>(loc, types, sumSquareOp.getResult());
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, sqrtOp.getResult());
|
|
||||||
return matchSuccess();
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
struct ReduceLogSumOpPattern : public RewritePattern {
|
|
||||||
ReduceLogSumOpPattern(MLIRContext *context)
|
|
||||||
: RewritePattern(ONNXReduceLogSumOp::getOperationName(),
|
|
||||||
{ONNXReduceSumOp::getOperationName(),
|
|
||||||
ONNXLogOp::getOperationName()},
|
|
||||||
1, context) {}
|
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto loc = op->getLoc();
|
|
||||||
auto opInput = op->getOperands()[0]; // %X
|
|
||||||
auto opResults = op->getResults();
|
|
||||||
auto opAttrs = op->getAttrs();
|
|
||||||
|
|
||||||
// Rewrite
|
|
||||||
ONNXReduceSumOp sumOp;
|
|
||||||
{
|
|
||||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
|
||||||
sumOp = rewriter.create<ONNXReduceSumOp>(
|
|
||||||
loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
|
|
||||||
}
|
|
||||||
|
|
||||||
ONNXLogOp logOp;
|
|
||||||
{
|
|
||||||
SmallVector<Type, 4> types;
|
|
||||||
for (auto v : opResults) {
|
|
||||||
types.emplace_back(v.getType());
|
|
||||||
}
|
|
||||||
logOp = rewriter.create<ONNXLogOp>(loc, types, sumOp.getResult());
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, logOp.getResult());
|
|
||||||
return matchSuccess();
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
struct ReduceLogSumExpOpPattern : public RewritePattern {
|
|
||||||
ReduceLogSumExpOpPattern(MLIRContext *context)
|
|
||||||
: RewritePattern(ONNXReduceLogSumExpOp::getOperationName(),
|
|
||||||
{ONNXExpOp::getOperationName(),
|
|
||||||
ONNXReduceLogSumOp::getOperationName()},
|
|
||||||
1, context) {}
|
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto loc = op->getLoc();
|
|
||||||
auto opInput = op->getOperands()[0]; // %X
|
|
||||||
auto opResults = op->getResults();
|
|
||||||
auto opAttrs = op->getAttrs();
|
|
||||||
|
|
||||||
// Rewrite
|
|
||||||
ONNXExpOp expOp;
|
|
||||||
{
|
|
||||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
|
||||||
expOp = rewriter.create<ONNXExpOp>(
|
|
||||||
loc, UnrankedTensorType::get(elementType), opInput);
|
|
||||||
}
|
|
||||||
|
|
||||||
ONNXReduceLogSumOp logSumOp;
|
|
||||||
{
|
|
||||||
SmallVector<Type, 4> types;
|
|
||||||
for (auto v : opResults) {
|
|
||||||
types.emplace_back(v.getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 1> values;
|
|
||||||
values.emplace_back(expOp.getResult());
|
|
||||||
|
|
||||||
SmallVector<NamedAttribute, 4> attrs;
|
|
||||||
for (auto attr : opAttrs) {
|
|
||||||
attrs.emplace_back(attr);
|
|
||||||
}
|
|
||||||
logSumOp = rewriter.create<ONNXReduceLogSumOp>(loc, types, values, attrs);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, logSumOp.getResult());
|
|
||||||
return matchSuccess();
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
struct ReduceSumSquareOpPattern : public RewritePattern {
|
|
||||||
ReduceSumSquareOpPattern(MLIRContext *context)
|
|
||||||
: RewritePattern(ONNXReduceSumSquareOp::getOperationName(),
|
|
||||||
{ONNXMulOp::getOperationName(),
|
|
||||||
ONNXReduceSumOp::getOperationName()},
|
|
||||||
1, context) {}
|
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto loc = op->getLoc();
|
|
||||||
auto opInput = op->getOperands()[0]; // %X
|
|
||||||
auto opResults = op->getResults();
|
|
||||||
auto opAttrs = op->getAttrs();
|
|
||||||
|
|
||||||
// Rewrite
|
|
||||||
ONNXMulOp mulOp;
|
|
||||||
{
|
|
||||||
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
|
||||||
mulOp = rewriter.create<ONNXMulOp>(
|
|
||||||
loc, UnrankedTensorType::get(elementType), opInput, opInput);
|
|
||||||
}
|
|
||||||
|
|
||||||
ONNXReduceSumOp sumOp;
|
|
||||||
{
|
|
||||||
SmallVector<Type, 4> types;
|
|
||||||
for (auto v : opResults) {
|
|
||||||
types.emplace_back(v.getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 1> values;
|
|
||||||
values.emplace_back(mulOp.getResult());
|
|
||||||
|
|
||||||
SmallVector<NamedAttribute, 4> attrs;
|
|
||||||
for (auto attr : opAttrs) {
|
|
||||||
attrs.emplace_back(attr);
|
|
||||||
}
|
|
||||||
sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, sumOp.getResult());
|
|
||||||
return matchSuccess();
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Rewrite:
|
// Rewrite:
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//=- onnx_rewrite.td - Pattern Match Rewriting for ONNX -*- tablegen -*----===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Defines language-specific pattern match optimizations for ONNX using
|
||||||
|
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef ONNX_REWRITE
|
||||||
|
#define ONNX_REWRITE
|
||||||
|
|
||||||
|
#ifndef OP_BASE
|
||||||
|
include "dialect/onnx/onnx.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
/// Note: The DRR definition used for defining patterns is shown below:
|
||||||
|
///
|
||||||
|
/// class Pattern<
|
||||||
|
/// dag sourcePattern, list<dag> resultPatterns,
|
||||||
|
/// list<dag> additionalConstraints = [],
|
||||||
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
|
/// >;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
|
||||||
|
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
|
||||||
|
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
|
||||||
|
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
|
||||||
|
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
|
||||||
|
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
|
||||||
|
|
||||||
|
#endif // ONNX_REWRITE
|
Loading…
Reference in New Issue